Training an Image Segmentation Network in PyTorch (Code Walkthrough)

The full project is available on GitHub:
https://github.com/yaoyi30/PyTorch_Image_Segmentation

I. Main Workflow for Training an Image Segmentation Network

  1. Build the dataset
  2. Preprocess the data, including data augmentation, standardization, and normalization
  3. Build the network model
  4. Set the learning rate, optimizer, loss function, and other hyperparameters
  5. Train and validate

II. Brief Description of Each Step

1. Build the dataset

This article uses the portrait segmentation dataset released by Supervisely; Baidu Netdisk link:

https://pan.baidu.com/s/1B8eBqg7XROHOsm5OLw-t9g  extraction code: 52ss


In the project root, create a datasets folder containing images and labels subfolders, which hold the images and their corresponding masks. Inside each, create train and val subfolders for the training and validation splits. The structure is as follows:

datasets/
  images/    # images
     train/
        img1.jpg
        img2.jpg
         .
         .
         .
     val/
        img1.jpg
        img2.jpg
         .
         .
         .
  labels/     # masks
     train/
        img1.png
        img2.png
         .
         .
         .
     val/
        img1.png
        img2.png
         .
         .
         .
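Under this layout, every image in images/ is paired with a mask of the same base name (with a .png extension) in labels/. A small sketch of that path mapping (mask_path_for is a hypothetical helper, shown only to illustrate the convention):

```python
import os

def mask_path_for(image_path):
    # swap images/ for labels/ and the image extension for .png
    root, _ = os.path.splitext(image_path)
    return (root + '.png').replace(os.path.join('datasets', 'images'),
                                   os.path.join('datasets', 'labels'))

p = mask_path_for(os.path.join('datasets', 'images', 'train', 'img1.jpg'))
# p is the path of the mask that pairs with img1.jpg
```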

2. Data preprocessing

Resize the images to a uniform size, convert them to tensors, and standardize them; the preprocessed images can then be fed to the network. Data augmentation can be applied to the training set to improve the network's generalization; the validation set is not augmented.

    # training preprocessing and augmentation
    train_transform = Compose([
                                    Resize(args.input_size), # resize images to a uniform size
                                    RandomHorizontalFlip(0.5),  # augmentation: random horizontal flip
                                    ToTensor(), # convert to tensor, values scaled to [0, 1]
                                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # standardization
                             ])
    # validation preprocessing
    val_transform = Compose([
                                    Resize(args.input_size), # resize images to a uniform size
                                    ToTensor(), # convert to tensor, values scaled to [0, 1]
                                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # standardization
                            ])

3. Build the network model

This article builds a simple image segmentation network, named Simplify_Net.

    model = Simplify_Net(args.nb_classes)

4. Set the learning rate, optimizer, loss function, and other hyperparameters

    # loss function: segmentation is pixel-wise classification, so cross-entropy is a natural choice
    loss_function = nn.CrossEntropyLoss()
    # optimizer (with initial learning rate and weight decay)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
    # learning-rate schedule: OneCycleLR (warmup to max_lr, then annealing), stepped once per epoch
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.max_lr, total_steps=args.epochs, verbose=True)
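For segmentation, nn.CrossEntropyLoss treats every pixel as its own classification problem: the logits have shape (N, C, H, W) and the target has shape (N, H, W) of integer class indices. A NumPy sketch of the per-pixel computation (toy values, not from the project):

```python
import numpy as np

def pixelwise_cross_entropy(logits, target):
    """Mean cross-entropy over all pixels.
    logits: (N, C, H, W) raw scores; target: (N, H, W) integer class ids."""
    # numerically stable log-softmax over the class channel
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    n, _, h, w = logits.shape
    # pick the log-probability of the correct class at each pixel
    ni, hi, wi = np.meshgrid(np.arange(n), np.arange(h), np.arange(w), indexing='ij')
    return -log_probs[ni, target, hi, wi].mean()

# two pixels, two classes; both pixels strongly predict their correct class
logits = np.array([[[[5.0, 0.0]],
                    [[0.0, 5.0]]]])   # shape (1, 2, 1, 2)
target = np.array([[[0, 1]]])         # shape (1, 1, 2)
loss = pixelwise_cross_entropy(logits, target)  # small, since both pixels are right
```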

5. Train and validate

    # train and validate the model; the function itself is implemented in utils/engine.py
    history = train_and_val(args.epochs, model, train_loader,val_loader,loss_function, optimizer,scheduler,args.output_dir,device,args.nb_classes)

III. Detailed Walkthrough of the Project Code Files

train.py

Defines the training entry point and the overall training workflow.

1. Import the required libraries and files

import os
import torch
import torch.nn as nn
from models.Simplify_Net import Simplify_Net
from utils.engine import train_and_val,plot_pix_acc,plot_miou,plot_loss,plot_lr
import argparse
import numpy as np
from utils.transform import Resize,Compose,ToTensor,Normalize,RandomHorizontalFlip
from utils.datasets import SegData

2. Training argument settings

def get_args_parser():
    parser = argparse.ArgumentParser('Image Segmentation Train', add_help=False)
    parser.add_argument('--batch_size', default=32, type=int,help='Batch size for training')
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--input_size', default=[224,224],nargs='+',type=int,help='images input size')
    parser.add_argument('--data_path', default='./datasets/', type=str,help='dataset path')

    parser.add_argument('--init_lr', default=1e-5, type=float,help='initial lr')
    parser.add_argument('--max_lr', default=1e-3, type=float,help='max lr')
    parser.add_argument('--weight_decay', default=1e-5, type=float,help='weight decay')

    parser.add_argument('--nb_classes', default=2, type=int,help='number of segmentation classes')
    parser.add_argument('--output_dir', default='./output_dir',help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',help='device to use for training / testing')
    parser.add_argument('--num_workers', default=4, type=int)

    return parser

3. Define the main function

def main(args):

    device = torch.device(args.device)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    train_transform = Compose([
                                    Resize(args.input_size),
                                    RandomHorizontalFlip(0.5),
                                    ToTensor(),
                                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                             ])

    val_transform = Compose([
                                    Resize(args.input_size),
                                    ToTensor(),
                                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                            ])

    train_dataset = SegData(image_path=os.path.join(args.data_path, 'images/train'),
                            mask_path=os.path.join(args.data_path, 'labels/train'),
                            data_transforms=train_transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.num_workers)

    val_dataset = SegData(image_path=os.path.join(args.data_path, 'images/val'),
                            mask_path=os.path.join(args.data_path, 'labels/val'),
                            data_transforms=val_transform)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.num_workers)

    model = Simplify_Net(args.nb_classes)
    loss_function = nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.max_lr, total_steps=args.epochs, verbose=True)

    history = train_and_val(args.epochs, model, train_loader,val_loader,loss_function, optimizer,scheduler,args.output_dir,device,args.nb_classes)

    plot_loss(np.arange(0,args.epochs),args.output_dir, history)
    plot_pix_acc(np.arange(0,args.epochs),args.output_dir, history)
    plot_miou(np.arange(0,args.epochs),args.output_dir, history)
    plot_lr(np.arange(0,args.epochs),args.output_dir, history)

4. Run

if __name__ == '__main__':
    # build the argument parser
    args = get_args_parser()
    # parse the arguments
    args = args.parse_args()
    # pass them to the main function
    main(args)

Run train.py; during training it prints, for each epoch, the learning rate, the training and validation metrics, and the elapsed time.

Simplify_Net.py

Defines the network structure: a simple Encoder-Decoder convolutional neural network.

import torch
import torch.nn as nn
from torch.nn.functional import interpolate

class Simplify_Net(nn.Module):
    def __init__(self, num_classes=2):
        super(Simplify_Net, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=3,out_channels=16,kernel_size=3,padding=1,stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(in_channels=16,out_channels=16,kernel_size=3,padding=1,stride=2)
        self.bn2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(in_channels=16,out_channels=16,kernel_size=3,padding=1,stride=2)
        self.bn3 = nn.BatchNorm2d(16)
        self.relu3 = nn.ReLU(inplace=True)

        self.upconv1 = nn.ConvTranspose2d(in_channels=16,out_channels=16,kernel_size=4,padding=1,stride=2)
        self.bn4 = nn.BatchNorm2d(16)
        self.relu4 = nn.ReLU(inplace=True)

        self.upconv2 = nn.ConvTranspose2d(in_channels=32,out_channels=16,kernel_size=4,padding=1,stride=2)
        self.bn5 = nn.BatchNorm2d(16)
        self.relu5 = nn.ReLU(inplace=True)

        self.conv_last = nn.Conv2d(in_channels=32,out_channels=num_classes,kernel_size=1,stride=1)


    def forward(self, x):

        x1 = self.relu1(self.bn1(self.conv1(x)))
        x2 = self.relu2(self.bn2(self.conv2(x1)))
        x3 = self.relu3(self.bn3(self.conv3(x2)))

        up1 = torch.cat([x2,self.relu4(self.bn4(self.upconv1(x3)))],dim=1)
        up2 = torch.cat([x1,self.relu5(self.bn5(self.upconv2(up1)))],dim=1)

        up3 = self.conv_last(up2)

        out = interpolate(up3, scale_factor=2, mode='bilinear', align_corners=False)

        return out
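The feature-map sizes can be checked by hand: each stride-2 3×3 conv with padding 1 halves the spatial size, each 4×4 transposed conv with stride 2 and padding 1 doubles it, and the final bilinear upsample restores the input resolution — which is also why the skip concatenations in forward line up. A quick sketch of the arithmetic for a 224×224 input:

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    # standard convolution output-size formula
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel=4, stride=2, padding=1):
    # transposed-convolution output-size formula (no output_padding)
    return (size - 1) * stride - 2 * padding + kernel

h = 224
h1 = conv_out(h)     # after conv1
h2 = conv_out(h1)    # after conv2
h3 = conv_out(h2)    # after conv3
u1 = deconv_out(h3)  # after upconv1: must equal h2 for torch.cat with x2
u2 = deconv_out(u1)  # after upconv2: must equal h1 for torch.cat with x1
out = 2 * u2         # after interpolate(scale_factor=2): back to the input size
```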

utils/datasets.py

Defines the dataset class for loading the data.

import os
from torch.utils.data import Dataset
from PIL import Image


class SegData(Dataset):
    def __init__(self, image_path, mask_path, data_transforms=None):
        self.image_path = image_path
        self.mask_path = mask_path

        self.images = os.listdir(self.image_path)
        self.masks = os.listdir(self.mask_path)
        self.transform = data_transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_filename = self.images[idx]
        # the mask shares the image's base name but uses the .png extension
        mask_filename = os.path.splitext(image_filename)[0] + '.png'

        image = Image.open(os.path.join(self.image_path, image_filename)).convert('RGB')
        mask = Image.open(os.path.join(self.mask_path, mask_filename)).convert('L')

        if self.transform is not None:
            image, mask = self.transform(image ,mask)

        return image, mask

utils/transform.py

Defines the data preprocessing transform classes.

import numpy as np
import random
import torch
from torchvision.transforms import functional as F

# resize img and mask to a uniform size
class Resize(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target=None):
        image = F.resize(image, self.size)
        if target is not None:
            target = F.resize(target, self.size, interpolation=F.InterpolationMode.NEAREST)
        return image, target

# random horizontal flip
class RandomHorizontalFlip(object):
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target=None):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            if target is not None:
                target = F.hflip(target)
        return image, target

# standardization
class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target

# convert img to a tensor with values scaled to [0, 1]; convert the label directly to a tensor
class ToTensor(object):
    def __call__(self, image, target):
        image = F.to_tensor(image)
        if target is not None:
            target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target

# chain the individual preprocessing transforms together
class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask=None):
        for t in self.transforms:
            image, mask = t(image, mask)
        return image, mask
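Each transform here takes and returns an (image, target) pair so that geometric operations stay synchronized between the image and its mask — this is why the project defines its own Compose instead of using torchvision.transforms.Compose, whose transforms accept a single input. A minimal sketch of the same pattern with dummy numeric transforms:

```python
class AddOne:
    # toy joint transform: acts on the image and, when present, the target
    def __call__(self, image, target=None):
        image = image + 1
        if target is not None:
            target = target + 1
        return image, target

class PairCompose:
    # same chaining logic as Compose in utils/transform.py
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask=None):
        for t in self.transforms:
            image, mask = t(image, mask)
        return image, mask

pipeline = PairCompose([AddOne(), AddOne()])
img, msk = pipeline(10, 100)  # both values pass through every transform
```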

utils/metrics.py

Defines the Evaluator class that computes pixel accuracy, MIoU, and other metrics.

import numpy as np

class Evaluator(object):
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,)*2)
    # pixel accuracy
    def Pixel_Accuracy(self):
        Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return Acc

    def Pixel_Accuracy_Class(self):
        Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        Acc = np.nanmean(Acc)
        return Acc
    # per-class IoU and MIoU
    def Mean_Intersection_over_Union(self):
        IoU = np.diag(self.confusion_matrix) / (
                    np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
                    np.diag(self.confusion_matrix))
        MIoU = np.nanmean(IoU)
        return IoU,MIoU

    def Frequency_Weighted_Intersection_over_Union(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
                    np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
                    np.diag(self.confusion_matrix))

        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        return FWIoU

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
    # accumulate a batch into the confusion matrix
    def add_batch(self, gt_image, pre_image):
        assert gt_image.shape == pre_image.shape
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)
    # reset the confusion matrix
    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
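All of these metrics derive from the confusion matrix, where entry (i, j) counts pixels of ground-truth class i predicted as class j; _generate_matrix builds it with the num_class * gt + pred bincount trick. A self-contained NumPy sketch on a toy 2-class batch:

```python
import numpy as np

def confusion(gt, pred, num_class):
    mask = (gt >= 0) & (gt < num_class)
    # encode each (gt, pred) pair as one integer, then histogram the codes
    label = num_class * gt[mask].astype(int) + pred[mask]
    return np.bincount(label, minlength=num_class ** 2).reshape(num_class, num_class)

gt   = np.array([[0, 0], [1, 1]])   # ground-truth mask, 2x2 pixels
pred = np.array([[0, 1], [1, 1]])   # prediction: one background pixel wrong
cm = confusion(gt, pred, 2)         # [[1, 1], [0, 2]]

pix_acc = np.diag(cm).sum() / cm.sum()   # 3 of 4 pixels correct
iou = np.diag(cm) / (cm.sum(axis=1) + cm.sum(axis=0) - np.diag(cm))
miou = np.nanmean(iou)                   # mean of 1/2 (class 0) and 2/3 (class 1)
```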

utils/engine.py

Defines the training, validation, and metric-plotting functions.

1. Import the required libraries and files

import os
import torch
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils.metrics import Evaluator
import numpy as np

2. Training and validation function

def train_and_val(epochs, model, train_loader, val_loader,criterion, optimizer,scheduler,output_dir,device,nb_classes):

    train_loss = []
    val_loss = []
    train_pix_acc = []
    val_pix_acc = []
    train_miou = []
    val_miou = []
    learning_rate = []
    best_miou = 0

    segmetric_train = Evaluator(nb_classes)
    segmetric_val = Evaluator(nb_classes)

    model.to(device)

    fit_time = time.time()
    for e in range(epochs):

        torch.cuda.empty_cache()
        segmetric_train.reset()
        segmetric_val.reset()

        since = time.time()
        training_loss = 0

        model.train()
        with tqdm(total=len(train_loader)) as pbar:
            for image, label in train_loader:

                image = image.to(device)
                label = label.to(device)

                output = model(image)
                loss = criterion(output, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pred = output.data.cpu().numpy()
                label = label.cpu().numpy()
                pred = np.argmax(pred, axis=1)

                training_loss += loss.item()
                segmetric_train.add_batch(label, pred)
                pbar.update(1)

        model.eval()
        validation_loss = 0

        with torch.no_grad():
            with tqdm(total=len(val_loader)) as pb:
                for image, label in val_loader:

                    image = image.to(device)
                    label = label.to(device)

                    output = model(image)
                    loss = criterion(output, label)

                    pred = output.data.cpu().numpy()
                    label = label.cpu().numpy()
                    pred = np.argmax(pred, axis=1)

                    validation_loss += loss.item()
                    segmetric_val.add_batch(label, pred)
                    pb.update(1)

        train_loss.append(training_loss / len(train_loader))
        val_loss.append(validation_loss / len(val_loader))

        train_pix_acc.append(segmetric_train.Pixel_Accuracy())
        val_pix_acc.append(segmetric_val.Pixel_Accuracy())

        train_miou.append(segmetric_train.Mean_Intersection_over_Union()[1])
        val_miou.append(segmetric_val.Mean_Intersection_over_Union()[1])

        learning_rate.append(scheduler.get_last_lr())

        torch.save(model.state_dict(), os.path.join(output_dir,'last.pth'))
        if best_miou < segmetric_val.Mean_Intersection_over_Union()[1]:
            # update the running best so 'best.pth' is only overwritten on improvement
            best_miou = segmetric_val.Mean_Intersection_over_Union()[1]
            torch.save(model.state_dict(), os.path.join(output_dir,'best.pth'))


        print("Epoch:{}/{}..".format(e + 1, epochs),
              "Train Pix Acc: {:.3f}".format(segmetric_train.Pixel_Accuracy()),
              "Val Pix Acc: {:.3f}".format(segmetric_val.Pixel_Accuracy()),
              "Train MIoU: {:.3f}".format(segmetric_train.Mean_Intersection_over_Union()[1]),
              "Val MIoU: {:.3f}".format(segmetric_val.Mean_Intersection_over_Union()[1]),
              "Train Loss: {:.3f}".format(training_loss / len(train_loader)),
              "Val Loss: {:.3f}".format(validation_loss / len(val_loader)),
              "Time: {:.2f}s".format((time.time() - since)))

        scheduler.step()

    history = {'train_loss': train_loss, 'val_loss': val_loss ,'train_pix_acc': train_pix_acc, 'val_pix_acc': val_pix_acc,'train_miou': train_miou, 'val_miou': val_miou,'lr':learning_rate}
    print('Total time: {:.2f} m'.format((time.time() - fit_time) / 60))

    return history

3. Plot the loss curves

def plot_loss(x,output_dir, history):
    plt.plot(x, history['val_loss'], label='val', marker='o')
    plt.plot(x, history['train_loss'], label='train', marker='o')
    plt.title('Loss per epoch')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.savefig(os.path.join(output_dir,'loss.png'))
    plt.clf()


4. Plot the pixel-accuracy curves

def plot_pix_acc(x,output_dir, history):
    plt.plot(x, history['train_pix_acc'], label='train_pix_acc', marker='x')
    plt.plot(x, history['val_pix_acc'], label='val_pix_acc', marker='x')
    plt.title('Pix Acc per epoch')
    plt.ylabel('pixel accuracy')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.savefig(os.path.join(output_dir,'pix_acc.png'))
    plt.clf()

The network structure is fairly simple, so the pixel accuracy is not especially high.

5. Plot the MIoU curves

def plot_miou(x,output_dir, history):
    plt.plot(x, history['train_miou'], label='train_miou', marker='x')
    plt.plot(x, history['val_miou'], label='val_miou', marker='x')
    plt.title('MIoU per epoch')
    plt.ylabel('miou')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.savefig(os.path.join(output_dir,'miou.png'))
    plt.clf()

The network structure is fairly simple, so the MIoU is not especially high.

6. Plot the learning-rate curve

def plot_lr(x,output_dir,  history):
    plt.plot(x, history['lr'], label='learning_rate', marker='x')
    plt.title('learning rate per epoch')
    plt.ylabel('Learning_rate')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.savefig(os.path.join(output_dir,'learning_rate.png'))
    plt.clf()

The learning-rate curve shows that roughly the first 30 epochs are the warmup phase, with a peak learning rate of 0.001.

predict.py

Runs prediction on a single image.

1. Import the required libraries and files

import argparse
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from models.Simplify_Net import Simplify_Net
from PIL import Image

2. Prediction argument settings

def get_args_parser():
    parser = argparse.ArgumentParser('Predict Image', add_help=False)
    parser.add_argument('--image_path', default='./people-man-model-glasses-46219.jpeg', type=str,help='path of the image to predict')
    parser.add_argument('--input_size', default=[224,224],nargs='+',type=int,help='images input size')
    parser.add_argument('--weights', default='./output_dir/last.pth', type=str,help='weights path')
    parser.add_argument('--nb_classes', default=2, type=int,help='number of segmentation classes')
    parser.add_argument('--device', default='cuda',help='device to use for training / testing')

    return parser

3. Define the main function

def main(args):
    device = torch.device(args.device)

    image = Image.open(args.image_path).convert('RGB')
    img_size = image.size

    transforms = T.Compose([
        T.Resize(args.input_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])

    model = Simplify_Net(args.nb_classes)

    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)

    model.to(device)
    model.eval()

    input_tensor = transforms(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        pred = output.argmax(1).squeeze(0).cpu().numpy().astype(np.uint8)

    # for the 2-class case, scale the 0/1 class indices to 0/255 so the saved mask is visible
    mask = Image.fromarray(pred * 255)
    # nearest-neighbor resampling keeps the mask values discrete
    out = mask.resize(img_size, resample=Image.NEAREST)
    out.save("result.png")

4. Run

if __name__ == '__main__':
    # build the argument parser
    args = get_args_parser()
    # parse the arguments
    args = args.parse_args()
    # pass them to the main function
    main(args)

Run predict.py to save the model's prediction result.

eval.py

Evaluates the model's overall metrics.

1. Import the required libraries and files

import argparse
from utils.transform import Resize,Compose,ToTensor,Normalize,RandomHorizontalFlip
from utils.datasets import SegData
import torch
import os
import numpy as np
from tqdm import tqdm
from models.Simplify_Net import Simplify_Net
from utils.metrics import Evaluator

2. Evaluation argument settings

def get_args_parser():
    parser = argparse.ArgumentParser('Eval Model', add_help=False)
    parser.add_argument('--batch_size', default=1, type=int,help='Batch size for evaluation')
    parser.add_argument('--input_size', default=[224,224],nargs='+',type=int,help='images input size')
    parser.add_argument('--data_path', default='./datasets/', type=str,help='dataset path')
    parser.add_argument('--weights', default='./output_dir/best.pth', type=str,help='weights path')
    parser.add_argument('--nb_classes', default=2, type=int,help='number of segmentation classes')
    parser.add_argument('--device', default='cuda',help='device to use for training / testing')
    parser.add_argument('--num_workers', default=4, type=int)

    return parser

3. Define the main function

def main(args):

    device = torch.device(args.device)

    segmetric = Evaluator(args.nb_classes)
    segmetric.reset()

    val_transform = Compose([
                                    Resize(args.input_size),
                                    ToTensor(),
                                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                            ])

    val_dataset = SegData(image_path=os.path.join(args.data_path, 'images/val'),
                            mask_path=os.path.join(args.data_path, 'labels/val'),
                            data_transforms=val_transform)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.num_workers)

    model = Simplify_Net(args.nb_classes)

    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)

    model.to(device)
    model.eval()

    classes = ["background","human"]

    with torch.no_grad():
        with tqdm(total=len(val_loader)) as pbar:
            for image, label in val_loader:
                output = model(image.to(device))
                pred = output.data.cpu().numpy()
                label = label.cpu().numpy()
                pred = np.argmax(pred, axis=1)
                segmetric.add_batch(label, pred)
                pbar.update(1)

    pix_acc = segmetric.Pixel_Accuracy()
    every_iou,miou = segmetric.Mean_Intersection_over_Union()

    print("Pixel Accuracy is :", pix_acc)
    print("==========Every IOU==========")
    for name,prob in zip(classes,every_iou):
        print(name+" : "+str(prob))
    print("=============================")
    print("MiOU is :", miou)

4. Run

if __name__ == '__main__':
    # build the argument parser
    args = get_args_parser()
    # parse the arguments
    args = args.parse_args()
    # pass them to the main function
    main(args)

Run eval.py to print the model's pixel accuracy on the validation set, the MIoU, and the per-class IoU.

export_onnx.py

Converts the trained model to ONNX format for downstream deployment.

1. Import the required libraries and files

import torch
from models.Simplify_Net import Simplify_Net
import argparse

2. ONNX export argument settings

def get_args_parser():
    parser = argparse.ArgumentParser('Export Onnx', add_help=False)
    parser.add_argument('--input_size', default=[224,224],nargs='+',type=int,help='images input size')
    parser.add_argument('--weights', default='./output_dir/best.pth', type=str,help='weights path')
    parser.add_argument('--nb_classes', default=2, type=int,help='number of segmentation classes')

    return parser

3. Define the main function

def main(args):

    x = torch.randn(1, 3, args.input_size[0],args.input_size[1])
    input_names = ["input"]
    out_names = ["output"]

    model = Simplify_Net(args.nb_classes)

    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)

    model.eval()

    # note: recent PyTorch versions reject training=False (they expect a TrainingMode enum), so it is omitted here
    torch.onnx.export(model, x, args.weights.replace('.pth', '.onnx'), export_params=True, input_names=input_names, output_names=out_names)
    print('To simplify the exported model, run: python -m onnxsim best.onnx best_sim.onnx\n')

4. Run

if __name__ == '__main__':
    # build the argument parser
    args = get_args_parser()
    # parse the arguments
    args = args.parse_args()
    # pass them to the main function
    main(args)

Run export_onnx.py, then simplify the model.

The ONNX model structure before (left) and after (right) simplification can then be compared.
