PyTorch 训练图像分割网络(代码讲解)

整个工程文件已放到Github上
https://github.com/yaoyi30/PyTorch_Image_Segmentation

一、训练图像分割网络主要流程

  1. 构建数据集
  2. 数据预处理,包括数据增强以及数据标准化和归一化
  3. 构建网络模型
  4. 设置学习率、优化器、损失函数等超参数
  5. 训练和验证

二、各个流程简要说明

1. 构建数据集

本文使用supervisely 发布的人像分割数据集,百度网盘地址:

https://pan.baidu.com/s/1B8eBqg7XROHOsm5OLw-t9g 提取码: 52ss

在这里插入图片描述

在工程目录下,新建datasets文件夹,在文件夹内分别新建images和labels文件夹,用来放图片和对应的mask图片,之后在两个文件夹内新建train和val文件夹用来存放训练和验证数据,结构如下:

datasets/
  images/    # images
     train/
        img1.jpg
        img2.jpg
         .
         .
         .
     val/
        img1.jpg
        img2.jpg
         .
         .
         .
  labels/     # masks
     train/
        img1.png
        img2.png
         .
         .
         .
     val/
        img1.png
        img2.png
         .
         .
         .

2. 数据预处理

将图像resize到统一大小,之后转为tensor格式再进行标准化,预处理之后的图片可以正常输入网络,对于训练集可以采取一些数据增强手段来增强网络的泛化能力,验证集不做数据增强。

    # Training preprocessing and data augmentation
    train_transform = Compose([
                                    Resize(args.input_size), # resize image (and mask) to a common size
                                    RandomHorizontalFlip(0.5),  # augmentation: horizontal flip with probability 0.5
                                    ToTensor(), # convert to tensor; image values become 0-1
                                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # per-channel standardization
                             ])
    # Validation preprocessing (no augmentation)
    val_transform = Compose([
                                    Resize(args.input_size), # resize image (and mask) to a common size
                                    ToTensor(), # convert to tensor; image values become 0-1
                                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # per-channel standardization
                            ])

3. 构建网络模型

本文搭建了一个简单的图像分割网络,命名为Simplify_Net。

    # Build the segmentation network; nb_classes sets the number of output channels.
    model = Simplify_Net(args.nb_classes)

4. 设置学习率、优化器、损失函数等超参数

    # Loss: segmentation is pixel-wise classification, so cross-entropy is used
    loss_function = nn.CrossEntropyLoss()
    # Optimizer (initial learning rate and weight decay)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
    # Learning-rate schedule: one-cycle (warm-up then cosine annealing) peaking at args.max_lr.
    # NOTE(review): per the PyTorch OneCycleLR docs the starting lr is max_lr / div_factor
    # (div_factor defaults to 25), which appears to override lr=args.init_lr above — confirm.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.max_lr, total_steps=args.epochs, verbose=True)

5. 训练和验证

	# Train and validate the model; the implementation lives in utils/engine.py
	history = train_and_val(args.epochs, model, train_loader,val_loader,loss_function, optimizer,scheduler,args.output_dir,device,args.nb_classes) 

三、工程代码文件详细讲解

train.py

定义训练的入口函数,以及训练所需要的流程

1. 导入相应的库和文件

import os
import torch
import torch.nn as nn
from models.Simplify_Net import Simplify_Net
from utils.engine import train_and_val,plot_pix_acc,plot_miou,plot_loss,plot_lr
import argparse
import numpy as np
from utils.transform import Resize,Compose,ToTensor,Normalize,RandomHorizontalFlip
from utils.datasets import SegData

2. 训练参数设置

def get_args_parser():
    """Return the argument parser holding the training hyper-parameters.

    ``add_help=False`` is kept from the original, so ``-h`` is not registered.
    Presumably this lets the parser serve as a parent parser; confirm before
    changing it.
    """
    parser = argparse.ArgumentParser('Image Segmentation Train', add_help=False)
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size for training')
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--input_size', default=[224, 224], nargs='+', type=int, help='images input size')
    parser.add_argument('--data_path', default='./datasets/', type=str, help='dataset path')

    parser.add_argument('--init_lr', default=1e-5, type=float, help='initial lr')
    parser.add_argument('--max_lr', default=1e-3, type=float, help='max lr')
    parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay')

    parser.add_argument('--nb_classes', default=2, type=int, help='number of segmentation classes')
    # main() always creates this directory, so an empty value is not supported.
    parser.add_argument('--output_dir', default='./output_dir',
                        help='directory where checkpoints and curves are saved')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--num_workers', default=4, type=int)

    return parser

3. 定义主函数

def main(args):
    """Train Simplify_Net on the image/mask folders under ``args.data_path``.

    Builds the train/val datasets and loaders, an AdamW optimizer with a
    OneCycleLR schedule, runs ``train_and_val`` and saves the loss, pixel
    accuracy, MIoU and learning-rate curves into ``args.output_dir``.
    """
    device = torch.device(args.device)

    # exist_ok makes creation idempotent and avoids a check-then-create race.
    os.makedirs(args.output_dir, exist_ok=True)

    # Per-channel normalization constants shared by both pipelines.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    train_transform = Compose([
        Resize(args.input_size),
        RandomHorizontalFlip(0.5),  # augmentation only for the training set
        ToTensor(),
        Normalize(mean=mean, std=std),
    ])

    val_transform = Compose([
        Resize(args.input_size),
        ToTensor(),
        Normalize(mean=mean, std=std),
    ])

    train_dataset = SegData(image_path=os.path.join(args.data_path, 'images/train'),
                            mask_path=os.path.join(args.data_path, 'labels/train'),
                            data_transforms=train_transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                                               shuffle=True, num_workers=args.num_workers)

    val_dataset = SegData(image_path=os.path.join(args.data_path, 'images/val'),
                          mask_path=os.path.join(args.data_path, 'labels/val'),
                          data_transforms=val_transform)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=args.batch_size,
                                             shuffle=False, num_workers=args.num_workers)

    model = Simplify_Net(args.nb_classes)
    # Segmentation is pixel-wise classification, hence cross-entropy.
    loss_function = nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
    # OneCycleLR starts every param group at max_lr / div_factor (default 25)
    # and overwrites the lr given to AdamW, which left --init_lr with no
    # effect. Derive div_factor so the schedule really starts at init_lr; use
    # the library default when init_lr is not positive.
    div_factor = args.max_lr / args.init_lr if args.init_lr > 0 else 25.0
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.max_lr, total_steps=args.epochs,
                                                    div_factor=div_factor, verbose=True)

    history = train_and_val(args.epochs, model, train_loader, val_loader, loss_function, optimizer,
                            scheduler, args.output_dir, device, args.nb_classes)

    # One x-axis (epoch indices) shared by all curves.
    epochs_axis = np.arange(args.epochs)
    plot_loss(epochs_axis, args.output_dir, history)
    plot_pix_acc(epochs_axis, args.output_dir, history)
    plot_miou(epochs_axis, args.output_dir, history)
    plot_lr(epochs_axis, args.output_dir, history)

4. 开始执行

if __name__ == '__main__':
    # Parse the command-line arguments and run training with them.
    args = get_args_parser().parse_args()
    main(args)

运行train.py,训练时会打印每一轮的学习率、训练集和验证集指标以及运行时间等信息
在这里插入图片描述

Simplify_Net.py

定义网络结构,本文定义一个简单的Encoder-Decoder结构的卷积神经网络

import torch
import torch.nn as nn
from torch.nn.functional import interpolate

class Simplify_Net(nn.Module):
    """Small encoder-decoder segmentation network with two skip connections.

    Three stride-2 convolutions downsample the input. Two transposed
    convolutions upsample again, and each result is concatenated with the
    encoder feature map of the same size. A 1x1 convolution maps to
    ``num_classes`` channels, and a final bilinear x2 upsampling restores the
    input resolution. Input height and width should be multiples of 8 so that
    the skip concatenations line up and the output matches the input size.

    The attribute names are the ``state_dict`` keys of saved checkpoints and
    must stay stable.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Encoder: each stage halves height and width.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(16)
        self.relu3 = nn.ReLU(inplace=True)

        # Decoder: each transposed convolution doubles height and width.
        self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(16)
        self.relu4 = nn.ReLU(inplace=True)

        # 32 input channels = 16 decoder channels + 16 skip channels.
        self.upconv2 = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1)
        self.bn5 = nn.BatchNorm2d(16)
        self.relu5 = nn.ReLU(inplace=True)

        # 1x1 classifier over the final 32-channel concatenation.
        self.conv_last = nn.Conv2d(32, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        # Encoder path.
        enc1 = self.relu1(self.bn1(self.conv1(x)))
        enc2 = self.relu2(self.bn2(self.conv2(enc1)))
        enc3 = self.relu3(self.bn3(self.conv3(enc2)))

        # Decoder path with skip connections (skip tensor first, then upsampled).
        dec1 = self.relu4(self.bn4(self.upconv1(enc3)))
        merged1 = torch.cat((enc2, dec1), dim=1)
        dec2 = self.relu5(self.bn5(self.upconv2(merged1)))
        merged2 = torch.cat((enc1, dec2), dim=1)

        logits = self.conv_last(merged2)
        # The logits are at half the input size; upsample once more.
        return interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)

utils/datasets.py

定义数据读取的类

import os
from torch.utils.data import Dataset
from PIL import Image


class SegData(Dataset):
    """Paired image/mask dataset for semantic segmentation.

    Every file in ``image_path`` is an input image. Its mask is read from
    ``mask_path`` under the same base name with a ``.png`` extension
    (``img1.jpg`` -> ``img1.png``, ``img1.jpeg`` -> ``img1.png``).

    Args:
        image_path: directory containing the input images.
        mask_path: directory containing the masks.
        data_transforms: optional callable ``(image, mask) -> (image, mask)``.
    """

    def __init__(self, image_path, mask_path, data_transforms=None):
        self.image_path = image_path
        self.mask_path = mask_path

        # Sorted so indices map to the same files on every run/platform
        # (os.listdir order is arbitrary).
        self.images = sorted(os.listdir(self.image_path))
        self.masks = sorted(os.listdir(self.mask_path))
        self.transform = data_transforms

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _mask_filename(image_filename):
        """Return the mask filename for ``image_filename``: same stem, ``.png`` extension."""
        # splitext swaps only the real extension. This handles both .jpg and
        # .jpeg images and leaves a 'jpeg' substring inside the stem alone.
        stem, _ = os.path.splitext(image_filename)
        return stem + '.png'

    def __getitem__(self, idx):
        image_filename = self.images[idx]
        mask_filename = self._mask_filename(image_filename)

        # Context managers close the underlying file handles right away.
        with Image.open(os.path.join(self.image_path, image_filename)) as img:
            image = img.convert('RGB')
        with Image.open(os.path.join(self.mask_path, mask_filename)) as msk:
            # Single-channel 'L' image holding the raw label values.
            mask = msk.convert('L')

        if self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask

utils/transform.py

定义数据预处理的类

import numpy as np
import random
import torch
from torchvision.transforms import functional as F

# 将img和mask resize到统一大小
class Resize:
    """Resize an image, and its segmentation mask when one is given, to ``size``."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target=None):
        resized_image = F.resize(image, self.size)
        if target is None:
            return resized_image, target
        # Masks hold class indices, so nearest-neighbour avoids inventing
        # intermediate label values.
        resized_target = F.resize(target, self.size, interpolation=F.InterpolationMode.NEAREST)
        return resized_image, resized_target

#随机左右翻转
class RandomHorizontalFlip:
    """With probability ``flip_prob``, mirror the image and its mask left-right together."""

    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target=None):
        # Exactly one random draw per call decides whether to flip.
        if not random.random() < self.flip_prob:
            return image, target
        flipped_image = F.hflip(image)
        flipped_target = target if target is None else F.hflip(target)
        return flipped_image, flipped_target

#标准化
class Normalize(object):
    """Standardize an image tensor channel-wise; the mask is passed through unchanged.

    Args:
        mean: per-channel means.
        std: per-channel standard deviations.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        # ``target`` defaults to None like Resize/RandomHorizontalFlip, so the
        # transform also works on an image without a mask.
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target

#img转tensor,值变为0-1之间,label直接转为tensor
class ToTensor(object):
    """Convert the image to a tensor with values in 0-1 and the mask to an int64 tensor.

    Mask values are kept as-is (class indices) rather than rescaled.
    """

    def __call__(self, image, target=None):
        # ``target`` defaults to None like Resize/RandomHorizontalFlip, so the
        # transform also works on an image without a mask.
        image = F.to_tensor(image)
        if target is not None:
            # int64 class indices, the target dtype nn.CrossEntropyLoss uses in training.
            target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target

#不同数据预处理类组合起来
class Compose:
    """Chain paired transforms; each one maps ``(image, mask)`` to ``(image, mask)``."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask=None):
        current = (image, mask)
        for step in self.transforms:
            new_image, new_mask = step(*current)
            current = (new_image, new_mask)
        return current

utils/metrics.py

定义计算像素准确率、MIoU等指标的类

import numpy as np

class Evaluator(object):
    """Accumulate a confusion matrix over batches and compute segmentation metrics.

    ``confusion_matrix[i, j]`` counts pixels whose ground-truth class is ``i``
    and whose predicted class is ``j`` (see ``_generate_matrix``). A ratio
    whose denominator is zero, such as the IoU of a class absent from both
    ground truth and prediction, evaluates to NaN without a RuntimeWarning.
    The ``nanmean`` averages then skip it.
    """

    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    @staticmethod
    def _safe_divide(numerator, denominator):
        """Element-wise ``numerator / denominator`` with 0/0 -> NaN and no RuntimeWarning."""
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.true_divide(numerator, denominator)

    def _iou_per_class(self):
        """Per-class IoU: diag / (row sum + column sum - diag), i.e. TP / (TP + FP + FN)."""
        cm = self.confusion_matrix
        intersection = np.diag(cm)
        union = cm.sum(axis=1) + cm.sum(axis=0) - intersection
        return self._safe_divide(intersection, union)

    # Pixel accuracy: correctly classified pixels / all counted pixels.
    def Pixel_Accuracy(self):
        return self._safe_divide(np.diag(self.confusion_matrix).sum(), self.confusion_matrix.sum())

    # Mean over classes of per-class recall (diag / ground-truth row sum).
    def Pixel_Accuracy_Class(self):
        acc = self._safe_divide(np.diag(self.confusion_matrix), self.confusion_matrix.sum(axis=1))
        return np.nanmean(acc)

    # Per-class IoU array and its NaN-ignoring mean (MIoU).
    def Mean_Intersection_over_Union(self):
        IoU = self._iou_per_class()
        MIoU = np.nanmean(IoU)
        return IoU, MIoU

    # IoU weighted by each class's ground-truth pixel frequency.
    def Frequency_Weighted_Intersection_over_Union(self):
        freq = self._safe_divide(self.confusion_matrix.sum(axis=1), self.confusion_matrix.sum())
        iu = self._iou_per_class()
        present = freq > 0
        return (freq[present] * iu[present]).sum()

    def _generate_matrix(self, gt_image, pre_image):
        """Return the ``num_class x num_class`` confusion matrix of one batch.

        Ground-truth values outside ``[0, num_class)`` (e.g. an ignore label)
        are excluded from the count.
        """
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        # Encode each (gt, pred) pair as one index so a single bincount fills the matrix.
        label = self.num_class * gt_image[mask].astype(int) + pre_image[mask].astype(int)
        count = np.bincount(label, minlength=self.num_class ** 2)
        return count.reshape(self.num_class, self.num_class)

    # Add one batch of ground-truth and predicted label maps (same shape).
    def add_batch(self, gt_image, pre_image):
        # Explicit check: ``assert`` would be stripped under ``python -O``.
        if gt_image.shape != pre_image.shape:
            raise ValueError(
                "gt_image and pre_image must have the same shape, got {} and {}".format(
                    gt_image.shape, pre_image.shape))
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    # Clear the accumulated counts (e.g. at the start of each epoch).
    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

utils/engine.py

定义具体的训练、验证以及绘制指标曲线的函数

1. 导入相应的库和文件

import os
import torch
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils.metrics import Evaluator
import numpy as np

2. 训练验证函数

def _run_epoch(model, loader, criterion, evaluator, device, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, otherwise evaluate.

    Per-pixel predictions (argmax over dim 1 of the model output) and labels
    are accumulated into ``evaluator``. Returns the mean per-batch loss, or
    0.0 for an empty loader.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    # Gradients are only needed for the training pass.
    grad_ctx = torch.enable_grad() if training else torch.no_grad()
    with grad_ctx, tqdm(total=len(loader)) as pbar:
        for image, label in loader:
            image = image.to(device)
            label = label.to(device)

            output = model(image)
            loss = criterion(output, label)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pred = np.argmax(output.detach().cpu().numpy(), axis=1)
            evaluator.add_batch(label.cpu().numpy(), pred)
            total_loss += loss.item()
            pbar.update(1)
    return total_loss / max(len(loader), 1)


def train_and_val(epochs, model, train_loader, val_loader, criterion, optimizer, scheduler, output_dir, device, nb_classes):
    """Train for ``epochs`` epochs, validating after each one.

    Each epoch saves ``last.pth`` to ``output_dir``. ``best.pth`` is saved
    whenever the validation MIoU exceeds the best value seen so far. The
    scheduler is stepped once per epoch.

    Returns:
        dict with per-epoch lists under 'train_loss', 'val_loss',
        'train_pix_acc', 'val_pix_acc', 'train_miou', 'val_miou' and 'lr'
        (the value of ``scheduler.get_last_lr()`` before stepping).
    """
    history = {'train_loss': [], 'val_loss': [], 'train_pix_acc': [], 'val_pix_acc': [],
               'train_miou': [], 'val_miou': [], 'lr': []}
    best_miou = 0

    segmetric_train = Evaluator(nb_classes)
    segmetric_val = Evaluator(nb_classes)

    model.to(device)

    fit_time = time.time()
    for e in range(epochs):

        torch.cuda.empty_cache()
        segmetric_train.reset()
        segmetric_val.reset()

        since = time.time()

        train_epoch_loss = _run_epoch(model, train_loader, criterion, segmetric_train, device, optimizer)
        val_epoch_loss = _run_epoch(model, val_loader, criterion, segmetric_val, device)

        # Compute each metric once and reuse it for history, checkpointing and logging.
        train_acc = segmetric_train.Pixel_Accuracy()
        val_acc = segmetric_val.Pixel_Accuracy()
        train_epoch_miou = segmetric_train.Mean_Intersection_over_Union()[1]
        val_epoch_miou = segmetric_val.Mean_Intersection_over_Union()[1]

        history['train_loss'].append(train_epoch_loss)
        history['val_loss'].append(val_epoch_loss)
        history['train_pix_acc'].append(train_acc)
        history['val_pix_acc'].append(val_acc)
        history['train_miou'].append(train_epoch_miou)
        history['val_miou'].append(val_epoch_miou)
        history['lr'].append(scheduler.get_last_lr())

        torch.save(model.state_dict(), os.path.join(output_dir, 'last.pth'))
        # Track the best validation MIoU so best.pth really holds the best
        # checkpoint (without this update it was rewritten every epoch).
        if val_epoch_miou > best_miou:
            best_miou = val_epoch_miou
            torch.save(model.state_dict(), os.path.join(output_dir, 'best.pth'))

        print("Epoch:{}/{}..".format(e + 1, epochs),
              "Train Pix Acc: {:.3f}".format(train_acc),
              "Val Pix Acc: {:.3f}".format(val_acc),
              "Train MIoU: {:.3f}".format(train_epoch_miou),
              "Val MIoU: {:.3f}".format(val_epoch_miou),
              "Train Loss: {:.3f}".format(train_epoch_loss),
              "Val Loss: {:.3f}".format(val_epoch_loss),
              "Time: {:.2f}s".format((time.time() - since)))

        scheduler.step()

    print('Total time: {:.2f} m'.format((time.time() - fit_time) / 60))

    return history

3. 打印损失值曲线

def plot_loss(x, output_dir, history):
    """Plot per-epoch validation and training loss and save it as loss.png in output_dir."""
    # Validation is drawn first, matching the original colour assignment.
    for key, name in (('val_loss', 'val'), ('train_loss', 'train')):
        plt.plot(x, history[key], label=name, marker='o')
    plt.title('Loss per epoch')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(output_dir, 'loss.png'))
    # Clear the current figure so the next plot starts empty.
    plt.clf()

在这里插入图片描述

4. 打印像素准确率曲线

def plot_pix_acc(x, output_dir, history):
    """Plot per-epoch training and validation pixel accuracy and save pix_acc.png in output_dir."""
    plt.plot(x, history['train_pix_acc'], label='train_pix_acc', marker='x')
    plt.plot(x, history['val_pix_acc'], label='val_pix_acc', marker='x')
    plt.title('Pix Acc per epoch')
    plt.ylabel('pixel accuracy')  # typo 'pixal' fixed
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(output_dir, 'pix_acc.png'))
    # Clear the current figure so the next plot starts empty.
    plt.clf()

网络结构较为简单,因此像素准确率不是特别高
在这里插入图片描述

5. 打印MIoU曲线

def plot_miou(x, output_dir, history):
    """Plot per-epoch training and validation MIoU and save miou.png in output_dir."""
    # The legend label of each curve is the same as its history key.
    for key in ('train_miou', 'val_miou'):
        plt.plot(x, history[key], label=key, marker='x')
    plt.title('MIoU per epoch')
    plt.ylabel('miou')
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(output_dir, 'miou.png'))
    plt.clf()

网络结构较为简单,因此MIoU不是特别高
在这里插入图片描述

6. 打印学习率曲线

def plot_lr(x, output_dir, history):
    """Plot the learning rate recorded for each epoch and save learning_rate.png in output_dir."""
    lr_values = history['lr']
    out_path = os.path.join(output_dir, 'learning_rate.png')

    plt.plot(x, lr_values, label='learning_rate', marker='x')
    plt.title('learning rate per epoch')
    plt.ylabel('Learning_rate')
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.savefig(out_path)
    plt.clf()

从学习率曲线可以看出,约前30轮为warmup阶段,最大学习率为0.001
在这里插入图片描述

predict.py

进行单张图片预测

1. 导入相应的库和文件

import argparse
import torch
import torch.nn as nn
import torchvision.transforms as T
from models.Simplify_Net import Simplify_Net
from PIL import Image

2. 单张预测参数设置

def get_args_parser():
    """Return the argument parser for single-image prediction.

    ``add_help=False`` is kept from the original (``-h`` is not registered).
    """
    parser = argparse.ArgumentParser('Predict Image', add_help=False)
    parser.add_argument('--image_path', default='./people-man-model-glasses-46219.jpeg', type=str,
                        help='path of the image to segment')
    parser.add_argument('--input_size', default=[224, 224], nargs='+', type=int, help='images input size')
    parser.add_argument('--weights', default='./output_dir/last.pth', type=str,
                        help='path of the trained model weights')
    parser.add_argument('--nb_classes', default=2, type=int, help='number of segmentation classes')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')

    return parser

3. 定义主函数

def main(args):
    """Segment one image and save the predicted class-index mask as result.png.

    The mask is resized back to the original image size; pixel values are the
    predicted class indices.
    """
    device = torch.device(args.device)

    image = Image.open(args.image_path).convert('RGB')
    img_size = image.size  # (width, height), the order Image.resize expects

    transforms = T.Compose([
        T.Resize(args.input_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])

    model = Simplify_Net(args.nb_classes)

    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)

    model.to(device)
    model.eval()

    input_tensor = transforms(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        # Cast to uint8 in torch: this script does not import numpy, so the
        # previous ``astype(np.uint8)`` raised NameError.
        pred = output.argmax(1).squeeze(0).to(torch.uint8).cpu().numpy()

    mask = Image.fromarray(pred)
    # Nearest-neighbour keeps class indices intact. Pillow's default filter
    # for an 'L' image interpolates and would blend labels at boundaries.
    out = mask.resize(img_size, resample=Image.NEAREST)
    out.save("result.png")

4. 开始执行

if __name__ == '__main__':
    # Parse the command-line arguments and run single-image prediction.
    args = get_args_parser().parse_args()
    main(args)

运行predict.py,保存模型预测的结果
在这里插入图片描述

eval.py

进行模型整体指标评价

1. 导入相应的库和文件

import argparse
from utils.transform import Resize,Compose,ToTensor,Normalize,RandomHorizontalFlip
from utils.datasets import SegData
import torch
import os
import numpy as np
from tqdm import tqdm
from models.Simplify_Net import Simplify_Net
from utils.metrics import Evaluator

2. 模型评价参数设置

def get_args_parser():
    """Return the argument parser for evaluating a checkpoint on the validation set.

    ``add_help=False`` is kept from the original (``-h`` is not registered).
    """
    parser = argparse.ArgumentParser('Eval Model', add_help=False)
    parser.add_argument('--batch_size', default=1, type=int, help='Batch size for evaluation')
    parser.add_argument('--input_size', default=[224, 224], nargs='+', type=int, help='images input size')
    parser.add_argument('--data_path', default='./datasets/', type=str, help='dataset path')
    parser.add_argument('--weights', default='./output_dir/best.pth', type=str,
                        help='path of the trained model weights')
    parser.add_argument('--nb_classes', default=2, type=int, help='number of segmentation classes')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--num_workers', default=4, type=int)

    return parser

3. 定义主函数

def main(args):
    """Evaluate ``args.weights`` on the validation split.

    Prints pixel accuracy, the IoU of every class and the MIoU.
    """
    device = torch.device(args.device)

    segmetric = Evaluator(args.nb_classes)
    segmetric.reset()

    val_transform = Compose([
        Resize(args.input_size),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    val_dataset = SegData(image_path=os.path.join(args.data_path, 'images/val'),
                          mask_path=os.path.join(args.data_path, 'labels/val'),
                          data_transforms=val_transform)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=args.batch_size,
                                             shuffle=False, num_workers=args.num_workers)

    model = Simplify_Net(args.nb_classes)

    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)

    model.to(device)
    model.eval()

    # Names for the two classes of the portrait dataset. Any extra classes get
    # generic names, so zip() below does not silently drop their IoU.
    classes = ["background", "human"]
    classes += ["class_{}".format(i) for i in range(len(classes), args.nb_classes)]

    with torch.no_grad():
        with tqdm(total=len(val_loader)) as pbar:
            for image, label in val_loader:
                output = model(image.to(device))
                pred = np.argmax(output.data.cpu().numpy(), axis=1)
                segmetric.add_batch(label.cpu().numpy(), pred)
                pbar.update(1)

    pix_acc = segmetric.Pixel_Accuracy()
    every_iou, miou = segmetric.Mean_Intersection_over_Union()

    print("Pixel Accuracy is :", pix_acc)
    print("==========Every IOU==========")
    for name, prob in zip(classes, every_iou):
        print(name + " : " + str(prob))
    print("=============================")
    print("MiOU is :", miou)

4. 开始执行

if __name__ == '__main__':
    # Parse the command-line arguments and run the evaluation.
    args = get_args_parser().parse_args()
    main(args)

运行eval.py,打印模型在验证集上的像素准确率,MIoU值和每一类的IoU值
在这里插入图片描述

export_onnx.py

将训练好的模型转onnx格式,以进行后续应用

1. 导入相应的库和文件

import torch
from models.Simplify_Net import Simplify_Net
import argparse

2. 转onnx模型参数设置

def get_args_parser():
    """Return the argument parser for exporting a checkpoint to ONNX.

    ``add_help=False`` is kept from the original (``-h`` is not registered).
    """
    parser = argparse.ArgumentParser('Export Onnx', add_help=False)
    parser.add_argument('--input_size', default=[224, 224], nargs='+', type=int, help='images input size')
    parser.add_argument('--weights', default='./output_dir/best.pth', type=str,
                        help='path of the trained model weights to export')
    parser.add_argument('--nb_classes', default=2, type=int, help='number of segmentation classes')

    return parser

3. 定义主函数

def main(args):
    """Load trained weights into Simplify_Net and export the model to ONNX.

    The ONNX file is written next to the weights file, with the weights
    file's extension replaced by ``.onnx``.
    """
    import os  # export_onnx.py does not import os at module level

    x = torch.randn(1, 3, args.input_size[0], args.input_size[1])
    input_names = ["input"]
    out_names = ["output"]

    model = Simplify_Net(args.nb_classes)

    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)

    model.eval()

    # Swap only the extension: str.replace('pth', 'onnx') would also rewrite
    # a 'pth' inside directory or file names (e.g. 'depth_model/').
    onnx_path = os.path.splitext(args.weights)[0] + '.onnx'
    sim_path = os.path.splitext(onnx_path)[0] + '_sim.onnx'

    # No ``training`` argument: the exporter's default mode is EVAL, matching
    # model.eval() above, and torch.onnx.export expects a TrainingMode enum
    # there, not a bool.
    torch.onnx.export(model, x, onnx_path, export_params=True,
                      input_names=input_names, output_names=out_names)
    print('please run: python -m onnxsim {} {}\n'.format(onnx_path, sim_path))

4. 开始执行

if __name__ == '__main__':
    # Parse the command-line arguments and run the ONNX export.
    args = get_args_parser().parse_args()
    main(args)

运行export_onnx.py,之后进行模型的简化
在这里插入图片描述

简化之前(左)和之后(右)的onnx模型结构对比
在这里插入图片描述

  • 19
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
query.prepare("INSERT INTO users (username, password) VALUES (:username, :password)"); query.bindValue(":username", username); query.bindValue(":password", password); if (query.exec()) { QMessageBox::information(this, "成功", "注册成功,请图像分割推理代码解析需要结合具体的算法和框架进行讲解,下面以常用登录!"); } else { QMessageBox::critical(this, "错误", "注册失败:" + query.lastError().text()); } 的语义分割算法DeepLabv3+和常见的深度学习框架PyTorch为例进行简 } // 显示添加事件界面 void showAddEvent() { addEventNameEdit->clear(); addEventLocationEdit->clear(); addEventReminderCheck->setChecked(false); addEventStartDateEdit->setDate(QDate::currentDate()); 要说明。 DeepLabv3+是一种深度学习语义分割模型,它基于全卷积网络 addEventStartTimeEdit->setTime(QTime::currentTime()); addEventEndDateEdit->setDate(QDate::currentDate()); 结构和空间金字塔池化模块,能够对图像进行像素级别的分类和分割。 addEventEndTimeEdit->setTime(QTime::currentTime()); addEventWidget->setVisible(true); setCentralWidget(addEventWidget); } // 保存事件 void saveEvent() { QString name = addEventNameEdit->text(); QString location在PyTorch框架下,它的推理代码大致分为以下几个步骤: 1.导入模型 = addEventLocationEdit->text(); QDateTime startDateTime = QDateTime(addEventStartDateEdit->date(), addEventStartTimeEdit->time()); 和预处理函数 ```python import torch from torchvision import transforms from models.deeplabv3plus import DeepLabV QDateTime endDateTime = QDateTime(addEventEndDateEdit->date(), addEventEndTimeEdit->time()); bool reminder = addEventReminderCheck3Plus # 加载模型 model = DeepLabV3Plus(num_classes=21, backbone='resnet101', output_stride->isChecked(); QSqlQuery query; query.prepare("INSERT INTO events (name, start_time, end_time, location, reminder,=16) model.load_state_dict(torch.load('deeplabv3plus_resnet101.pth')) model.eval() # 预处理函数 username) " "VALUES (:name, :start_time, :end_time, :location, :reminder, :username)"); querypreprocess = transforms.Compose([ transforms.Resize((512, 512)), transforms.ToTensor(), transforms.Normalize(mean=[0.485,.bindValue(":name", name); query.bindValue(":start_time", startDateTime); query.bindValue(":end_time", endDateTime); 
query.bindValue(":location", location); query.bindValue(":reminder", reminder); query.bindValue(":username", currentUser); 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) ``` if (query.exec()) { QMessageBox::information(this, "成功", "事件添加成功!"); addEventWidget->setVisible(false2.读取图像并进行预处理 ```python from PIL import Image # 读取图像并进行预处理); eventListWidget->setVisible(true); setCentralWidget(eventListWidget); updateEventList(); } else { QMessageBox image = Image.open('test.jpg') image = preprocess(image).unsqueeze(0) # 将图像输入模型进行推理 ::critical(this, "错误", "事件添加失败:" + query.lastError().text()); } } // 编辑事件 void editEvent() { if (eventListTable->selectedItems().isEmpty()) { QMessageBox::warning(this, "错误with torch.no_grad(): output = model(image)['out'] ``` 3.后处理和可视化 ```python import numpy", "请选择要编辑的事件!"); return; } int row = eventListTable->selectedItems().at(0)-> as np import matplotlib.pyplot as plt # 对输出进行后处理 output = output.squeeze(0) output = torch.argmax(output,row(); int id = eventListTable->item(row, 0)->text().toInt(); QSqlQuery query; query.prepare("SELECT * FROM events WHERE id = :id"); query.bindValue(":id", id); if (query.exec() && query.next()) dim=0).numpy() # 可视化结果 plt.imshow(output) plt.show() ``` 以上是DeepLabv3+在 { addEventNameEdit->setText(query.value("name").toString()); addEventLocationEdit->setText(query.value("location").toString()); addEventReminderCheck->setChecked(query.value("reminder").toBool()); addEventStartDateEdit->setDateTime(query.value("startPyTorch框架下的推理代码,其中包含了模型加载、预处理、推理和后处理等步_time").toDateTime()); addEventEndDateEdit->setDateTime(query.value("end_time").toDateTime()); addEventWidget->setVisible骤。在实际应用中,还需要对代码进行优化和加速,以提高推理速度和效率。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

姚先生97

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值