Model Training System

The GUI communicates with the back-end training code through files, keeping the interface and the model training fully decoupled.


Writing the XML configuration file

https://blog.csdn.net/qq_41375318/article/details/112883753
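
How the GUI produces this file is up to the interface code; a minimal sketch with ElementTree follows (write_config is an illustrative helper, not part of the project; the tag names match what Train.py parses below, and the valid_dataset_path value is illustrative). Running it yields a file like the one shown next.

import xml.etree.ElementTree as ET

def write_config(path, params):
    # build <train_config><param><tag>value</tag>...</param></train_config>
    root = ET.Element('train_config')
    param = ET.SubElement(root, 'param')
    for tag, value in params.items():
        ET.SubElement(param, tag).text = str(value)
    ET.ElementTree(root).write(path, encoding='utf-8', xml_declaration=True)

write_config('train_config.xml', {
    'epoch': 100,
    'batchsize': 2,
    'train_dataset_path': r'F:\PycharmWorkPlace\ModelTrainingSystem\api\classification\cifar10_dataset.txt',
    'valid_dataset_path': r'F:\PycharmWorkPlace\ModelTrainingSystem\api\classification\cifar10_valid_dataset.txt',  # illustrative
    'learning_rate': 0.01,
})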

<?xml version="1.0" ?>
<train_config>
   <param>
      <epoch>100</epoch>
      <batchsize>2</batchsize>
      <train_dataset_path>F:\PycharmWorkPlace\ModelTrainingSystem\api\classification\cifar10_dataset.txt</train_dataset_path>
      <valid_dataset_path>F:\PycharmWorkPlace\ModelTrainingSystem\api\classification\cifar10_valid_dataset.txt</valid_dataset_path>
      <learning_rate>0.01</learning_rate>
   </param>
</train_config>

Train.py

1. Read the XML configuration file
2. Set up the parameters
3. Train
# == imports ==

# == step 0 parameter setup ==

# == step 1 data processing ==

# == step 2 model ==

# == step 3 loss function ==

# == step 4 optimizer ==

# == step 5 evaluation functions ==

# == step 6 training ==

# == step 7 training visualization ==

# == inference ==


# == imports ==
from torch.utils.data import Dataset
from PIL import Image
import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import xml.etree.ElementTree as ET  # ElementTree for parsing the XML config
import torch.optim as optim
import torch.nn as nn
import torch

# == step 0 parameter setup ==
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
norm_mean = [0.33424968, 0.33424437, 0.33428448]
norm_std = [0.24796878, 0.24796101, 0.24801227]
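# sketch (not part of the original script): per-channel statistics like the
# values above are typically estimated once from the raw training images,
# e.g. by stacking them as float tensors in [0, 1] and taking channel moments:
#   imgs = torch.stack([transforms.ToTensor()(Image.open(p).convert('RGB').resize((32, 32)))
#                       for p, _ in get_img_info(train_dataset_path)])  # (N, 3, 32, 32)
#   norm_mean, norm_std = imgs.mean(dim=(0, 2, 3)).tolist(), imgs.std(dim=(0, 2, 3)).tolist()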
epoch = None
batchsize = None
train_dataset_path = None
valid_dataset_path = None
learning_rate = None
path_saved_model = 'best_model.pth'

# parse the XML config
tree = ET.parse('train_config.xml')  # parse the file
root = tree.getroot()  # grab the root node

# assign the parsed values
for node in root.iter('epoch'):  # filter the 'epoch' tag under the root
    epoch = int(node.text)
for node in root.iter('batchsize'):  # filter the 'batchsize' tag
    batchsize = int(node.text)
for node in root.iter('train_dataset_path'):  # filter the 'train_dataset_path' tag
    train_dataset_path = node.text
for node in root.iter('valid_dataset_path'):  # filter the 'valid_dataset_path' tag
    valid_dataset_path = node.text
for node in root.iter('learning_rate'):  # filter the 'learning_rate' tag
    learning_rate = float(node.text)
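
# sanity check (not in the original script): fail fast if the XML is missing
# any required tag, instead of crashing later on a None value
for _name, _value in [('epoch', epoch), ('batchsize', batchsize),
                      ('train_dataset_path', train_dataset_path),
                      ('valid_dataset_path', valid_dataset_path),
                      ('learning_rate', learning_rate)]:
    if _value is None:
        raise ValueError('train_config.xml is missing the <%s> tag' % _name)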


# == step 1 data processing and loading ==

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),  # scales 0-255 pixel values to 0-1 and converts to a Tensor
    transforms.Normalize(norm_mean, norm_std),
])



# return (image path, label) pairs for every sample in the list file
def get_img_info(data_dir):
    data_info = []
    with open(data_dir, 'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            path_img, label = line.split(" ", 1)
            data_info.append((path_img, int(label)))
    return data_info


class LoadDataset(Dataset):
    # remember the list file and the transform
    def __init__(self, data_dir=None, transform=None):
        self.imgs_labels = get_img_info(data_dir)
        self.transform = transform

    # load and return one sample and its label
    def __getitem__(self, index):
        img_path, label = self.imgs_labels[index]
        img = Image.open(img_path).convert('RGB')  # force 3 channels so Normalize always applies
        if self.transform is not None:
            img = self.transform(img)  # apply the transform here (resize, ToTensor, ...)
        return img, label

    # the length bounds the valid indices
    def __len__(self):
        return len(self.imgs_labels)

train_dataset = LoadDataset(data_dir=train_dataset_path, transform=train_transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batchsize, shuffle=True)  # shuffle the samples during training

valid_dataset = LoadDataset(data_dir=valid_dataset_path, transform=train_transform)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batchsize)  # no shuffling needed for validation

# == step 2 model ==
from net.classification.ResNet import ResNet18
net = ResNet18(10, 512).to(device)  # swap the model here, e.g. net = se_resnet50(num_classes=5, pretrained=True)

# == step 3 loss function ==
criterion = nn.CrossEntropyLoss()
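# note: CrossEntropyLoss applies log-softmax internally, so the model should
# output raw logits rather than softmax probabilities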

# == step 4 optimizer ==
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)  # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # learning-rate decay: multiply lr by gamma every step_size epochs

# == step 5 evaluation functions ==
def evaluteTop1(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    return correct / total


def evaluteTop5(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            maxk = 5  # top-5
            y_resize = y.view(-1, 1)
            _, pred = logits.topk(maxk, 1, True, True)  # pred: (batch, 5) class indices
            # broadcasting (batch, 5) against (batch, 1) counts a hit whenever
            # the true label appears anywhere in the top 5
            correct += torch.eq(pred, y_resize).sum().float().item()
    return correct / total
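
# usage sketch (not wired in by the original script): these helpers could
# replace the manual bookkeeping in the validation block below, e.g.
#   print('top1:', evaluteTop1(net, valid_loader))
#   print('top5:', evaluteTop5(net, valid_loader))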

# == step 6 training ==

best = [0]  # track the best training accuracy so far, used to checkpoint the best model
for i in range(epoch):
    # training
    print("current_epoch:", i + 1)
    net.train()  # back to training mode (validation below switches to eval)
    correct = 0
    total_loss = 0
    for idx, data_info in enumerate(train_loader):
        inputs, labels = data_info
        inputs, labels = inputs.to(device), labels.to(device)
        # forward
        outputs = net(inputs)
        # backward
        optimizer.zero_grad()  # zero the gradients before backpropagation
        loss = criterion(outputs, labels)  # loss for this batch
        total_loss += loss.item()
        loss.backward()  # backpropagate the loss
        # update weights
        optimizer.step()  # update all parameters
        # accumulate classification statistics
        _, predicted = torch.max(outputs.data, 1)  # 1 = argmax over the class dimension
        correct += (predicted == labels).sum().item()  # count the correct predictions
    train_acc = correct / len(train_loader.dataset)
    print("loss:", total_loss)
    print("acc:", train_acc)
    scheduler.step()  # advance the learning-rate schedule
    # print the current learning rate
    print("current lr:", optimizer.state_dict()['param_groups'][0]['lr'])
    if max(best) <= train_acc:
        best.append(train_acc)
        torch.save(net.state_dict(), path_saved_model)

    # validation, every 5 epochs
    if (i + 1) % 5 == 0:
        print("valid")
        net.eval()
        val_correct = 0
        with torch.no_grad():  # no gradients needed for evaluation
            for idx, data_info in enumerate(valid_loader):
                inputs, labels = data_info
                inputs, labels = inputs.to(device), labels.to(device)
                # forward only
                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                val_correct += (predicted == labels).sum().item()
        print("val_acc:", val_correct / len(valid_loader.dataset))

Predict.py

Task selection is done with a switch-style statement; since Python has no native switch (before the match statement in 3.10), this amounts to dict-based dispatch.
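
A minimal sketch of such dispatch (the task names and handler functions are illustrative placeholders, not the project's actual API):

def predict_classification(image_path):
    # load best_model.pth, run the classifier, return the predicted class
    ...

def predict_detection(image_path):
    # placeholder for another task type
    ...

DISPATCH = {
    'classification': predict_classification,
    'detection': predict_detection,
}

def predict(task, image_path):
    try:
        return DISPATCH[task](image_path)
    except KeyError:
        raise ValueError('unknown task: %s' % task)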

Intermediate model-training parameters are saved with XML as well:

<?xml version="1.0" ?>
<train_config>
   <param>
      <epoch>100</epoch>
      <batchsize>2</batchsize>
      <dataset_path>F:\PycharmWorkPlace\ModelTrainingSystem\api\classification\cifar10_dataset.txt</dataset_path>
   </param>
</train_config>
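
A sketch of how such values could be written back during training with ElementTree (save_progress is an illustrative helper; it assumes the <param> layout shown above and creates any missing tag):

import xml.etree.ElementTree as ET

def save_progress(path, current_epoch, current_lr):
    # rewrite selected tags of an existing progress XML in place
    tree = ET.parse(path)
    param = tree.getroot().find('param')
    for tag, value in (('epoch', current_epoch), ('learning_rate', current_lr)):
        node = param.find(tag)
        if node is None:  # create the tag if the file does not have it yet
            node = ET.SubElement(param, tag)
        node.text = str(value)
    tree.write(path, encoding='utf-8', xml_declaration=True)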

Every dataset imported into the system is described by a standard txt list file: the raw data path on the left, the annotation on the right, separated by a single space (a sketch for generating such a file follows the example below).

G:\dataset\split_data\split_data\test\0\0_116.png 0
G:\dataset\split_data\split_data\test\0\0_116.png 0
G:\dataset\split_data\split_data\test\0\0_116.png 0
G:\dataset\split_data\split_data\test\0\0_116.png 0
G:\dataset\split_data\split_data\test\0\0_116.png 0
G:\dataset\split_data\split_data\test\0\0_116.png 1
G:\dataset\split_data\split_data\test\0\0_116.png 1
G:\dataset\split_data\split_data\test\0\0_116.png 1
G:\dataset\split_data\split_data\test\0\0_116.png 1
G:\dataset\split_data\split_data\test\0\0_116.png 1
G:\dataset\split_data\split_data\test\0\0_116.png 2
G:\dataset\split_data\split_data\test\0\0_116.png 2
G:\dataset\split_data\split_data\test\0\0_116.png 2
G:\dataset\split_data\split_data\test\0\0_116.png 2
G:\dataset\split_data\split_data\test\0\0_116.png 2
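
A sketch of how such a list file can be generated from a folder tree whose subdirectory names are the class labels, as in the split_data layout above (make_list_file is an illustrative helper):

import os

def make_list_file(data_root, out_txt):
    # one line per image: "<path> <label>", separated by a single space
    with open(out_txt, 'w') as f:
        for label in sorted(os.listdir(data_root)):
            class_dir = os.path.join(data_root, label)
            if not os.path.isdir(class_dir):
                continue
            for name in sorted(os.listdir(class_dir)):
                f.write('%s %s\n' % (os.path.join(class_dir, name), label))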
