Model Training System
The GUI communicates with the training backend through files, keeping the two decoupled: the GUI writes an XML config file and the backend reads it.
Writing the XML file
<?xml version="1.0" ?>
<train_config>
    <param>
        <epoch>100</epoch>
        <batchsize>2</batchsize>
        <learning_rate>0.01</learning_rate>
        <train_dataset_path>F:\PycharmWorkPlace\ModelTrainingSystem\api\classification\cifar10_dataset.txt</train_dataset_path>
        <valid_dataset_path>F:\PycharmWorkPlace\ModelTrainingSystem\api\classification\cifar10_dataset.txt</valid_dataset_path>
    </param>
</train_config>
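On the GUI side, all that is needed is to serialize the user's choices into this file. A minimal sketch using only the standard library (write_config is an illustrative name, not part of the project; the learning_rate value and the reuse of the same annotation file for validation are placeholders):

import xml.etree.ElementTree as ET

def write_config(path, epoch, batchsize, learning_rate,
                 train_dataset_path, valid_dataset_path):
    """Serialize the GUI's training settings into train_config.xml."""
    root = ET.Element('train_config')
    param = ET.SubElement(root, 'param')
    for tag, value in [('epoch', epoch),
                       ('batchsize', batchsize),
                       ('learning_rate', learning_rate),
                       ('train_dataset_path', train_dataset_path),
                       ('valid_dataset_path', valid_dataset_path)]:
        ET.SubElement(param, tag).text = str(value)
    ET.ElementTree(root).write(path, encoding='utf-8', xml_declaration=True)

write_config('train_config.xml', 100, 2, 0.01,
             r'F:\PycharmWorkPlace\ModelTrainingSystem\api\classification\cifar10_dataset.txt',
             r'F:\PycharmWorkPlace\ModelTrainingSystem\api\classification\cifar10_dataset.txt')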
Train.py
1. Read the XML config file
2. Set up the training parameters
3. Train
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import xml.etree.ElementTree as ET
import torch.optim as optim
import torch.nn as nn
import torch

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Per-channel mean/std of the training images, used for input normalization
norm_mean = [0.33424968, 0.33424437, 0.33428448]
norm_std = [0.24796878, 0.24796101, 0.24801227]

epoch = None
batchsize = None
train_dataset_path = None
valid_dataset_path = None
learning_rate = None
path_saved_model = 'best_model.pth'

# 1. Read the parameters the GUI wrote into train_config.xml
tree = ET.parse('train_config.xml')
root = tree.getroot()
for node in root.iter('epoch'):
    epoch = int(node.text)
for node in root.iter('batchsize'):
    batchsize = int(node.text)
for node in root.iter('train_dataset_path'):
    train_dataset_path = node.text
for node in root.iter('valid_dataset_path'):
    valid_dataset_path = node.text
for node in root.iter('learning_rate'):
    learning_rate = float(node.text)

# 2. Set up transforms, datasets, model, loss, optimizer and LR schedule
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

def get_img_info(data_dir):
    """Parse the annotation txt: each line is '<image_path> <label>'."""
    data_info = []
    with open(data_dir, 'r') as file:  # the 'rU' mode was removed in Python 3.11
        for line in file:
            line = line.strip()
            if not line:
                continue
            path_img, label = line.split(" ", 1)
            data_info.append((path_img, int(label)))
    return data_info

class LoadDataset(Dataset):
    def __init__(self, data_dir=None, transform=None):
        self.imgs_labels = get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        img_path, label = self.imgs_labels[index]
        img = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs_labels)

train_dataset = LoadDataset(data_dir=train_dataset_path, transform=train_transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batchsize, shuffle=True)
valid_dataset = LoadDataset(data_dir=valid_dataset_path, transform=train_transform)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batchsize)

from net.classification.ResNet import ResNet18
net = ResNet18(10, 512).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

def evaluteTop1(model, loader):
    """Top-1 accuracy of `model` over `loader`."""
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct / total

def evaluteTop5(model, loader):
    """Top-5 accuracy: a sample counts as correct if the label is among the 5 highest logits."""
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            maxk = max((1, 5))
            y_resize = y.view(-1, 1)
            _, pred = logits.topk(maxk, 1, True, True)
        correct += torch.eq(pred, y_resize).sum().float().item()
    return correct / total

# 3. Train, keeping the checkpoint with the best training accuracy
best_acc = 0.0
for i in range(epoch):
    print("current_epoch:", i + 1)
    net.train()
    correct = 0
    total_loss = 0
    for idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).sum().item()
    train_acc = correct / len(train_loader.dataset)
    print("loss:", total_loss)
    print("acc:", train_acc)
    scheduler.step()
    print("current learning rate:", optimizer.state_dict()['param_groups'][0]['lr'])
    if train_acc >= best_acc:
        best_acc = train_acc
        torch.save(net.state_dict(), path_saved_model)
    # Validate every 5 epochs
    if (i + 1) % 5 == 0:
        print("valid")
        net.eval()
        val_correct = 0
        with torch.no_grad():
            for idx, (inputs, labels) in enumerate(valid_loader):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                val_correct += (predicted == labels).sum().item()
        print("val_acc:", val_correct / len(valid_loader.dataset))
Predict.py
Model selection is handled with a switch-style statement, as sketched below.
Intermediate model-training parameters are saved in XML, using the same train_config.xml format shown above.
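Python has no built-in switch statement; from Python 3.10 onward, match/case plays that role. A minimal sketch of how Predict.py's selection could be dispatched (the load_model helper and the single "classification" branch are assumptions based on the net.classification package above, not the project's actual API):

import torch

def load_model(task):
    """Dispatch on the task name with match/case and return a model ready for inference."""
    match task:
        case "classification":
            from net.classification.ResNet import ResNet18
            net = ResNet18(10, 512)
        case _:
            raise ValueError(f"unknown task: {task}")
    net.load_state_dict(torch.load("best_model.pth", map_location="cpu"))
    net.eval()
    return net

net = load_model("classification")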
All datasets imported into the system follow the same txt standard: each line holds the raw data path on the left and the annotation label on the right, separated by a single space, for example (a sketch for generating such a file follows the sample):
G:\dataset\split_data\split_data\test\0\0_116.png 0
G:\dataset\split_data\split_data\test\0\0_116.png 1
G:\dataset\split_data\split_data\test\0\0_116.png 2
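A minimal sketch of generating such an annotation file from a class-per-folder layout like the one in the sample paths (make_annotation_file is an illustrative name; it assumes each subfolder is named after its numeric label):

import os

def make_annotation_file(root_dir, out_txt):
    """Write one '<image_path> <label>' line per image, using the folder name as the label."""
    with open(out_txt, 'w') as out:
        for label_name in sorted(os.listdir(root_dir)):
            class_dir = os.path.join(root_dir, label_name)
            if not os.path.isdir(class_dir):
                continue
            for fname in sorted(os.listdir(class_dir)):
                out.write(f"{os.path.join(class_dir, fname)} {int(label_name)}\n")

make_annotation_file(r"G:\dataset\split_data\split_data\test", "cifar10_dataset.txt")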