pytorch搭建训练自己数据集的模型(预处理、读取自己的图片、进行训练和测试、保存模型、加载模型和测试)

第一阶段:读取图片并保存为.txt

import os
import random

# Train/test split ratio: 80% train, 20% test.
train_ratio = 0.8
test_ratio = 1 - train_ratio

# Root directory of the dataset; one sub-directory per class.
rootdata = '/home/hsy/PycharmProjects/数据集/5月下旬'

train_list, test_list = [], []

data_list = []

# Class label of the directory currently being walked.  It starts at -1 so
# that the first walk entry (the root itself, which normally contains only
# sub-directories and no image files) is consumed before the first real class
# directory receives label 0.
class_flag = -1

# os.walk yields (root, dirs, files) for every directory under rootdata:
#   root  - path of the directory currently visited
#   dirs  - sub-directories inside it
#   files - plain (non-directory) files inside it
# FIX: os.walk visits directories in arbitrary filesystem order, which made
# the class labels (and the train/test membership of each file) differ from
# run to run.  Sorting `dirs` in place and `files` makes the traversal --
# and therefore the label assignment -- deterministic.
for root, dirs, files in os.walk(rootdata):
    dirs.sort()
    files = sorted(files)

    # Full path of every file, e.g.
    # os.path.join(root, name) ->
    # '/home/hsy/PycharmProjects/数据集/5月下旬/train/鱼腥草/yuxingcao_1.jpg'
    paths = [os.path.join(root, name) for name in files]
    data_list.extend(paths)

    # First train_ratio of the files go to the training list, the remainder
    # to the test list; each line has the form "<path>\t<label>\n".
    split = int(len(files) * train_ratio)
    train_list.extend(path + '\t' + str(class_flag) + '\n' for path in paths[:split])
    test_list.extend(path + '\t' + str(class_flag) + '\n' for path in paths[split:])

    class_flag += 1

# Shuffle so that consecutive samples are not all from the same class.
random.shuffle(train_list)
random.shuffle(test_list)

# Persist the "<path>\t<label>" lines for the Dataset to read later.
with open('../data/train.txt', 'w', encoding='UTF-8') as f:
    f.writelines(train_list)

with open('../data/test.txt', 'w', encoding='UTF-8') as f:
    f.writelines(test_list)

print(test_list)

(此处原文插有运行结果截图,导出为纯文本后无法显示)
train.txt

/home/hsy/PycharmProjects/数据集/5月下旬/瞿麦/qumai_109.jpg	16
/home/hsy/PycharmProjects/数据集/5月下旬/洋金花/yangjinhua_33.jpg	4
/home/hsy/PycharmProjects/数据集/5月下旬/萱草/xuancao_1.jpg	19
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_137.jpg	12
/home/hsy/PycharmProjects/数据集/5月下旬/香加皮/xiangjiapi_50.jpg	17
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_117.jpg	12
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_136.jpg	12
/home/hsy/PycharmProjects/数据集/5月下旬/洋金花/yangjinhua_28.jpg	4
/home/hsy/PycharmProjects/数据集/5月下旬/金芥麦/jinjiemai_107.jpg	6
/home/hsy/PycharmProjects/数据集/5月下旬/何首乌/heshouwu_42.jpg	3
	.......

test.txt

/home/hsy/PycharmProjects/数据集/5月下旬/垂盆草/chuipencao_7.jpg	18
/home/hsy/PycharmProjects/数据集/5月下旬/夏枯草/xiakucao_124.jpg	2
/home/hsy/PycharmProjects/数据集/5月下旬/车前草/cheqiancao_106.jpg	8
/home/hsy/PycharmProjects/数据集/5月下旬/京大戟/jingdaji_39.jpg	7
/home/hsy/PycharmProjects/数据集/5月下旬/射干/shegan_76.jpg	5
/home/hsy/PycharmProjects/数据集/5月下旬/夏枯草/xiakucao_151.jpg	2
/home/hsy/PycharmProjects/数据集/5月下旬/牛蒡子/niubangzi_184.jpg	1
/home/hsy/PycharmProjects/数据集/5月下旬/决明子/juemingzi_100.jpg	10
/home/hsy/PycharmProjects/数据集/5月下旬/瞿麦/qumai_23.jpg	16
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_105.jpg	12
/home/hsy/PycharmProjects/数据集/5月下旬/决明子/juemingzi_92.jpg	10
/home/hsy/PycharmProjects/数据集/5月下旬/鱼腥草/yuxingcao_45.jpg	0
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_24.jpg	12
/home/hsy/PycharmProjects/数据集/5月下旬/金芥麦/jinjiemai_98.jpg	6
.......

第二阶段:改写Dataset,保证下阶段读取自己的数据集

import torch
from  PIL import  Image
import os
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
from torchvision import  transforms,utils,datasets
import numpy as np


#图像标准化
# transform_BN=transforms.Normalize((0.485,0.456,0.406),(0.226,0.224,0.225))


class LoadData(Dataset):
    """Dataset that serves (image tensor, int label) pairs listed in a txt
    index file whose lines have the form "<path>\\t<label>"."""

    def __init__(self, txt_path, train_flag=True):
        # List of [image_path, label_string] pairs parsed from the index file.
        self.imgs_info = self.get_imags(txt_path)
        # Selects the train-time vs. test-time transform in __getitem__.
        self.train_flag = train_flag

        # ImageNet mean/std normalization.  Normalize is stateless, so one
        # shared instance serves both pipelines.  (Flip augmentation was
        # present but commented out in the original -- left disabled.)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        self.transform_train = transforms.Compose([transforms.ToTensor(), normalize])
        self.transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    def get_imags(self, txt_path):
        """Parse the index file into a list of [path, label] pairs."""
        with open(txt_path, 'r', encoding='UTF-8') as f:
            return [line.strip().split('\t') for line in f.readlines()]

    def __getitem__(self, index):
        """Open, convert to RGB, transform and return one sample with its label."""
        img_path, label = self.imgs_info[index]
        img = Image.open(img_path).convert("RGB")
        transform = self.transform_train if self.train_flag else self.transform_test
        return transform(img), int(label)

    def __len__(self):
        """Number of samples listed in the index file."""
        return len(self.imgs_info)

第三阶段:读取自己的数据集并训练和测试

from torch import optim
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import time

from data.CreateDataloader import LoadData

def load_dataset(batch_size):
    """Build shuffled train and test DataLoaders from the txt index files.

    The train_flag argument (True/False) selects the corresponding transform
    pipeline inside LoadData.
    """
    datasets = [
        LoadData("../data/train.txt", True),
        LoadData("../data/test.txt", False),
    ]

    train_iter, test_iter = (
        torch.utils.data.DataLoader(
            dataset=ds, batch_size=batch_size, shuffle=True, num_workers=4
        )
        for ds in datasets
    )

    return train_iter, test_iter

def get_cur_lr(optimizer):
    """Return the learning rate of the optimizer's first param group
    (None when the optimizer has no param groups)."""
    groups = optimizer.param_groups
    return groups[0]['lr'] if groups else None


def learning_curve(record_train, record_test=None):
    """Plot per-epoch train (and optionally test) accuracy curves."""
    plt.style.use('ggplot')

    epochs = range(1, len(record_train) + 1)
    plt.plot(epochs, record_train, label='train acc')
    if record_test is not None:
        plt.plot(range(1, len(record_test) + 1), record_test, label="test acc")

    plt.legend(loc=4)  # lower-right corner
    plt.title("learning curve")
    plt.xticks(range(0, len(record_train) + 1, 5))
    plt.yticks(range(0, 101, 5))  # accuracy is a percentage
    plt.xlabel("epoch")
    plt.ylabel("accuracy")

    plt.show()

'''
model.train()
在使用 PyTorch 构建神经网络时,训练开始前会调用 model.train(),
作用是启用 batch normalization 的统计更新和 dropout。

model.eval()
测试过程中会使用 model.eval(),此时神经网络沿用训练阶段统计得到的
batch normalization 参数,并且不使用 dropout。
'''
def train(model, train_iter, criterion, optimizer, device, num_print, lr_scheduler=None):
    """Run one training epoch and return the final train accuracy (percent).

    model.train() enables batch-norm statistic updates and dropout.
    Every `num_print` steps the running mean loss, accuracy and current
    learning rate are printed.  Returns 0.0 when train_iter yields no
    batches (the original raised UnboundLocalError at the return in that
    case, because train_acc was only assigned inside the loop).
    """
    model.train()

    total, correct, train_loss = 0, 0, 0
    train_acc = 0.0  # defined up-front so an empty loader cannot crash the return
    start = time.time()

    for i, (inputs, labels) in enumerate(train_iter):
        inputs, labels = inputs.to(device), labels.to(device)

        output = model(inputs)
        loss = criterion(output, labels)

        # Standard backprop step: clear stale grads, differentiate, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total += labels.size(0)
        correct += torch.eq(output.argmax(dim=1), labels).sum().item()

        train_acc = 100 * correct / total

        # Periodic progress report with running mean loss and current lr.
        if (i + 1) % num_print == 0:
            print("step: [{}/{}], train_loss: {:.3f} | train_acc: {:6.3f}% | lr: {:.6f}"
                  .format(i + 1, len(train_iter), train_loss / (i + 1),
                          train_acc, get_cur_lr(optimizer)))

    # Step the scheduler once per epoch, after all batches.
    if lr_scheduler is not None:
        lr_scheduler.step()

    print("-----cost time:{:.4f}s----".format(time.time() - start))

    return train_acc


def test(model, test_iter, criterion, device, test_num):
    """Evaluate the model on test_iter and return accuracy in percent.

    model.eval() makes batch norm use its running statistics and disables
    dropout.  `test_num` is kept for interface compatibility with the
    caller; it is not used here.  Reports the mean loss over all batches:
    the original printed only the last batch's `loss`, which both misstated
    the epoch loss and raised UnboundLocalError when the loader was empty.
    The unused locals `j` and `caoyao_list` have been removed.
    """
    total, correct = 0, 0
    test_loss, num_batches = 0.0, 0

    model.eval()

    with torch.no_grad():
        print("*************************test***************************")

        for inputs, labels in test_iter:
            inputs, labels = inputs.to(device), labels.to(device)

            output = model(inputs)
            loss = criterion(output, labels)

            test_loss += loss.item()
            num_batches += 1
            total += labels.size(0)
            correct += torch.eq(output.argmax(dim=1), labels).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    test_acc = 100.0 * correct / total if total else 0.0
    mean_loss = test_loss / num_batches if num_batches else 0.0
    print("test_loss:{:.3} | test_acc:{:6.3f}%"
          .format(mean_loss, test_acc)
          )

    print("*************************************************************")

    return test_acc


from model.VggNet import *
from model.VGG11 import *
from model.ResNet18 import *

# ----- training hyper-parameters -----
batch_size = 14
num_epochs = 30
num_class = 20
learning_rate = 0.001
momentum = 0.9
weight_decay = 0.0005
num_print = 40  # print training progress every `num_print` steps
test_num = 0
device = "cuda" if torch.cuda.is_available() else "cpu"


def main():
    """Train the network, evaluate after every epoch, save it, plot curves."""
    # Swap in your own network model here.
    model = RestNet18_Net().to(device)

    train_iter, test_iter = load_dataset(batch_size)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
        nesterov=True,
    )

    # Decay the learning rate by 10x every 8 epochs.
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)

    train_acc, test_acc = [], []
    epoch_counter = 0
    for epoch in range(num_epochs):
        epoch_counter += 1
        print('=================epoch:[{}/{}]======================'.format(epoch + 1, num_epochs))
        train_acc.append(train(model, train_iter, criterion, optimizer, device, num_print, lr_scheduler))
        test_acc.append(test(model, test_iter, criterion, device, epoch_counter))

    print("Finished Training")
    # Save both the whole pickled model and just its parameters.
    torch.save(model, '../save_model/ResNet18/1.pth')
    torch.save(model.state_dict(), '../save_model/ResNet18/1_params.pth')

    learning_curve(train_acc, test_acc)


if __name__ == '__main__':
    main()

如果这段代码看不懂可以看:https://blog.csdn.net/m0_50127633/article/details/117045008,在这里我有比较详细的注释。

第四阶段:模型加载并进行测试

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image




def pridict():
    """Load the saved model and classify one image, printing its class index.

    (Function name kept as `pridict` -- a typo for `predict` -- so existing
    callers keep working.)
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    path = '../save_model/ResNet18/1.pth'

    # NOTE(security): torch.load unpickles arbitrary objects -- only load
    # checkpoint files you created yourself or fully trust.
    model = torch.load(path)
    model = model.to(device)

    model.eval()  # use running batch-norm stats, disable dropout

    # std fixed to 0.229 to match the normalization used at training time
    # (the original used 0.226 here, skewing the red channel at inference).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    with Image.open('/home/hsy/PycharmProjects/数据集/5月下旬/当归/danggui_49.jpg') as img:
        img = img.convert("RGB")  # no-op if the image is already RGB
        tensor = transform(img).unsqueeze(0)  # add the batch dimension
    tensor = tensor.to(device)

    with torch.no_grad():
        py = model(tensor)

    # torch.max(py, 1) returns (values, indices) along dim=1; the index of
    # the per-row maximum is the predicted class.
    _, predicted = torch.max(py, 1)
    classIndex = predicted.item()

    # FIX: the original line was `print"预测结果",classIndex)` -- a syntax error.
    print("预测结果", classIndex)


if __name__ == '__main__':
    pridict()

(此处原文插有预测结果截图,导出为纯文本后无法显示)
这是根据我自己的数据集进行写的,如果你要训练自己数据的话需要进行改写,欢迎指出不足。

  • 9
    点赞
  • 153
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
你可以按照以下步骤来实现CNN手写数字别的训练保存模型以及加载模型进行测试。 1. 导入所需的库: ```python import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms import torch.optim as optim import pandas as pd ``` 2. 创建一个自定义的CNN模型: ```python class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2) self.relu1 = nn.ReLU() self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2) self.relu2 = nn.ReLU() self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2) self.fc = nn.Linear(7*7*32, 10) def forward(self, x): out = self.conv1(x) out = self.relu1(out) out = self.maxpool1(out) out = self.conv2(out) out = self.relu2(out) out = self.maxpool2(out) out = out.view(out.size(0), -1) out = self.fc(out) return out model = CNN() ``` 3. 加载训练集数据并进行预处理: ```python transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) ``` 4. 定义损失函数和优化器: ```python criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) ``` 5. 进行模型训练: ```python total_epochs = 5 for epoch in range(total_epochs): running_loss = 0.0 for i, (inputs, labels) in enumerate(train_loader): optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if (i+1) % 100 == 0: print(f'Epoch [{epoch+1}/{total_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}') running_loss = 0.0 print('Training finished!') ``` 6. 保存训练好的模型: ```python torch.save(model.state_dict(), 'model.pth') print('Model saved!') ``` 7. 
在另一个文件中加载保存模型进行测试: ```python # 加载模型 model = CNN() model.load_state_dict(torch.load('model.pth')) model.eval() # 加载测试集数据 test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False) # 测试模型 correct = 0 total = 0 with torch.no_grad(): for inputs, labels in test_loader: outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = 100 * correct / total print(f'Test Accuracy: {accuracy}%') ``` 以上是一个简单的示例,展示了如何使用PyTorch构建、训练保存CNN模型,并在另一个文件中加载模型进行测试。你可以根据自己的需求进行修改和扩展。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值