PyTorch实战:DCGAN+MobileNet V3实现小样本数据集分类

提示:该思路和代码可以用于大部分小样本数据集分类


前言

1. 个人感触

近期参加了2023年第四届MathorCup高校数学建模挑战赛——大数据竞赛,一方面是因为好友的邀请,另一方面也是为了检验前段时间PyTorch的学习效果。由于一些原因,原本应该进行七天的比赛我们只开展了三天。不过好在整个过程中我的模型搭建很顺利,虽然结果可能不太理想,但总体而言对我来说是一个不错的体验(感受到了自身的成长)。

2. 题目分析

本次我们选择的是A题,目的是进行道路坑洼的分类。我认为该题目的难点在于训练集样本少(300张正常道路,100张坑洼道路),测试集样本大(4000张)。同时,训练集中图像尺寸不一致,很多图像中道路的占比较少。我们通过DCGAN网络对训练集进行数据增强,通过增强后训练集对MobileNet V3网络进行训练。我们采用的是双层MobileNet V3网络(使用Large和Small两个版本)对测试集进行一个预测,以下是预测的伪代码:

root = 原始数据所在路径

for filename in os.listdir(root)  # 遍历路径下的图片
   crops = Flips(filename)  # 将图片进行裁剪

   for img in range(crops)  # 遍历裁剪后的图片
       predict1 = MobileNetV3_Large(img)  # 对裁剪后的图像进行第一层预测

       if predict1 is road:  # 预测结果为有道路
           predict2 = MobileNetV3_Small(img)

           if predict2 is potholes:  # 预测结果为有坑洼
               label(img) = potholes  # img的标签为potholes
               label(filename) = potholes  # filename的标签为potholes

           else:
               label(img) = normal  # img的标签为normal

       else:
           label(img) = useless  # img的标签为useless

   if all(label(img)) is (normal or useless):  # 如果所有裁剪后图片标签均为normal或useless
       label(filename) = normal  # filename的标签为normal

save('test_result.csv')  # 保存结果至csv文件

一、准备工作

1. 图像预处理

由于训练集中的图像大小并非一致,对于图像尺寸较大的图像,我们需要将其裁剪为224×224的大小;对于图像尺寸较小的图像,我们需要将其Resize为224×224的大小。以下为Filps.py函数:

import os
import cv2 as cv
from PIL import Image
from torchvision import transforms


def crop_image(image_dir, output_path, size):
    for filename in os.listdir(image_dir):
        name, _ = filename.split('.')  # name = 原始图像名称
        file_path = os.path.join(image_dir, filename)  # 原始图像所在路径
        image = cv.imread(file_path)
        transform = transforms.Resize([size, size])
        h, w = image.shape[0:2]
        i = 0

        if h < size or w < size:
            image = Image.fromarray(image)  # 将其转为ndarray格式
            image = transform(image)  # Resize图像
            image.save(output_path + name + f"_{i}" + ".png")
        else:
            h_no = h // size
            w_no = w // size

            for row in range(0, h_no):
                for col in range(0, w_no):
                    cropped_img = image[size*row: size*(row+1), size*col: size*(col+1), :]
                    cv.imwrite(output_path + name + f"_{i}" + ".png", cropped_img)
                    i = i+1


if __name__ == "__main__":
    image_dir = "data/raw/"  # 原始图像路径
    output_path = "data/valid/"  # 保存图像路径
    size = 224  # 裁剪后的尺寸
    crop_image(image_dir, output_path, size)

上述代码中存在一个问题,即会舍弃无法整除的边缘部分图像。由于本数据集中道路主要集中在中心部位,所以舍弃边缘对于结果影响很少。不过我还是建议将上述代码进行优化,将边缘部分进行向左或向上补充以构成224×224大小的图像。

2. 自建数据集

主要通过torchvision.datasetsImageFolder函数搭建数据集,数据集在文件中的路径展示如下:

    root
    ├──Train
    |   ├──normal
    |   |   ├──****.png
    |   |   ├──****.png
    |   |   └──...
    |   └──potholes
    |       ├──****.png
    |       ├──****.png
    |       └──...
    └──Test  
        ├──normal
        |   ├──****.png
        |   ├──****.png
        |   └──...
        └──potholes
            ├──****.png
            ├──****.png
            └──...

PyTorch提供的函数非常方便,我们只需按照以上格式进行排列。ImageFolder函数将其按顺序打上标签,往往我们会将ImageFolder函数与DataLoader一起使用(后续代码有展示)。

二、DCGAN

话不多说,直接附上DCGAN完整代码,对网络结构感兴趣的可以自行学习:

import os
import torch
import argparse
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms

from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable


parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=224, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=20, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

cuda = True if torch.cuda.is_available() else False


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Initialize dataset
root = "../../data/road/normal"
transform = transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        )

if root == "../../data/road/normal":
    path = "normal"
    os.makedirs(path, exist_ok=True)
elif root == "../../data/road/potholes":
    path = "potholes"
    os.makedirs(path, exist_ok=True)


# Configure data loader
train_data = datasets.ImageFolder(root, transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=opt.batch_size,
    shuffle=True
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        #  Train Generator
        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        #  Train Discriminator
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[1], "%s/%d.png" % (path, batches_done), nrow=1, normalize=True)

三、MobileNet V3

采用MobileNet V3的原因是其为一个轻量化网络架构,满足题目中要求的轻量化和快速性。同时,也因为我个人显卡是RTX2060,采用近期很火的ViT模型难以实现。该网络是于2019年发布的,目前已经完全开源了,可以很轻松就找到网络代码,其中包含了Large和Small两个版本:

'''MobileNetV3 in PyTorch.
See the paper "Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation" for more details.
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class hswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out


class hsigmoid(nn.Module):
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out


class SeModule(nn.Module):
    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size),
            hsigmoid()
        )

    def forward(self, x):
        return x * self.se(x)


class Block(nn.Module):
    '''expand + depthwise + pointwise'''
    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
        super(Block, self).__init__()
        self.stride = stride
        self.se = semodule

        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_size != out_size:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_size),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.se != None:
            out = self.se(out)
        out = out + self.shortcut(x) if self.stride==1 else out
        return out


class MobileNetV3_Large(nn.Module):
    def __init__(self, num_classes=1000):
        super(MobileNetV3_Large, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
            Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(3, 40, 240, 80, hswish(), None, 2),
            Block(3, 80, 200, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
            Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
            Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
            Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
            Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
        )

        self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(960, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, 7)
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out


class MobileNetV3_Small(nn.Module):
    def __init__(self, num_classes=1000):
        super(MobileNetV3_Small, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
            Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
        )

        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(576)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(576, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, 7)
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out


def test():
    net = MobileNetV3_Small()
    x = torch.randn(2, 3, 224, 224)
    y = net(x)
    print(y.size())

三、模型运行与预测

1. main.py

将原始的训练集与DCGAN生成的图像结合构成新的训练集,对MobileNet V3模型进行训练:

import time
import torch
import pandas
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from mobilenetv3 import MobileNetV3_Small, MobileNetV3_Large


# 计算精度
def compute_accuracy(model, data_loader, device):

    correct = 0
    total = 0
    for i, (data, label) in enumerate(data_loader):
        data = data.to(device)
        label = label.to(device)
        out = model(data).argmax(dim=1)
        total += data.size(0)
        correct += torch.eq(out, label).sum()

    return correct/total*100


if __name__ == '__main__':

    epochs = 30  # 定义迭代次数
    batch_size = 16  # 定义每一批图像个数
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # 有GPU则使用GPU,否则使用CPU
    model = MobileNetV3_Large(num_classes=2).to(device)  # 实例化模型并传入GPU(CPU)中

    criterion = nn.CrossEntropyLoss()  # 重新定义交叉熵损失为criterion
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 采用优化器为Adam,学习率lr=0.01

    # Initialize dataset
    train_root = "data_1/train"
    test_root = "data_1/test"
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    )

    # Configure data loader
    train_data = datasets.ImageFolder(train_root, transform=transform)
    test_data = datasets.ImageFolder(test_root, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )

    start_time = time.time()
    train_accuracy_list, test_accuracy_list = [], []
    filename = 'best_model_1.pth'
    best_accuracy = 0

    for epoch in range(epochs):
        # 开始训练
        model.train()
        for batch_index, (features, targets) in enumerate(train_loader):
            features = features.to(device)  # 将图像传入GPU(CPU)
            targets = targets.to(device)  # 将标签传入GPU(CPU)
            output = model(features)  # 进行预测
            loss = criterion(output, targets)  # 计算损失

            optimizer.zero_grad()  # 初始化梯度
            loss.backward()  # 反向传播
            optimizer.step()  # 参数更新

            # 打印损失值Loss
            if not batch_index % 100:
                print(f'Epoch: {epoch+1:03d}/{epochs:03d} | '
                      f'Batch: {batch_index:04d}/{len(train_loader):04d} | '
                      f'Loss: {loss:.4f}')

        # 模型测试
        model.eval()
        with torch.set_grad_enabled(False):  # 取消梯度更新
            train_accuracy = compute_accuracy(model=model, data_loader=train_loader, device=device)  # 计算训练集精度
            test_accuracy = compute_accuracy(model=model, data_loader=test_loader, device=device)  # 计算测试集精度
            train_accuracy_list.append(train_accuracy.tolist())  # 存储测试集精度
            test_accuracy_list.append(test_accuracy.tolist())  # 存储训练集精度

            # 保存训练最好的模型
            if best_accuracy < test_accuracy:
                best_accuracy = test_accuracy
                state = {
                    'best_accuracy': best_accuracy,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, filename)

            # 打印精度Acc
            print(f'Epoch: {epoch+1:04d}/{epochs:04d} | '
                  f'Train Acc: {train_accuracy:.2f}% | '
                  f'Test Acc: {test_accuracy:.2f}%')

        # 打印时间Time
        elapsed = (time.time() - start_time)/60
        print(f'Time Elapsed: {elapsed:.2f}min')

    # 绘制整个过程精度变化
    plt.plot(range(epochs), train_accuracy_list, color='r')
    plt.plot(range(epochs), test_accuracy_list, color='b')
    plt.legend(['Training accuracy', 'Testing accuracy'], loc='upper left')
    plt.savefig('data_1.png')
    plt.show()

    # 打印到excel表格中
    df_train = pandas.DataFrame(train_accuracy_list, columns=['train'])
    df_test = pandas.DataFrame(test_accuracy_list, columns=['test'])
    df_train.to_excel("train-1.xlsx", index=False)
    df_test.to_excel("test-1.xlsx", index=False)

我训练了Large和Small两个版本,因为我的分类任务是进行了判断(可见伪代码)。具体情况可以根据自身的需求选择两个版本中的一个!


2. valid.py

需要根据自己的任务对valid.py函数中的代码进行一定的修改:

import os
import torch
from PIL import Image
from mobilenetv3 import MobileNetV3_Small, MobileNetV3_Large

root_dir = 'data/valid/'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model1 = MobileNetV3_Large(num_classes=2).to(device)
model2 = MobileNetV3_Small(num_classes=2).to(device)
state1 = torch.load('best_model_1.pth')  # 加载模型
state2 = torch.load('best_model_2.pth')  # 加载模型
model1.load_state_dict(state1['state_dict'])  # 加载模型参数
model2.load_state_dict(state2['state_dict'])  # 加载模型参数

for filename in os.listdir(root_dir):
    name, _ = filename.split('_')
    file_path = os.path.join(root_dir, filename)

    model1.eval()
    image = Image.open(file_path)
    image = image.to(device)

    with torch.no_grad():
        pre = model1(image)
    _, predicted = torch.max(pre, 1)
    Index = predicted[0]
    if Index is 1:  # 是道路
        model2.eval()

        with torch.no_grad():
            pre = model2(image)
        _, predicted = torch.max(pre, 1)
        if predicted is 1:  # 存在坑洼
            print('potholes')
        else:
            print('normal')
    else:
        print('normal')

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值