Training a Classifier on a Self-Collected Dataset

Results

This is a simple classifier network with a ResNet18 backbone. I only collected 4 classes; the recognition results are shown below.
(image: recognition results)
Below I explain step by step how it was done.

Step 1: Data Collection

First, collect the dataset with a crawler. I have already prepared the crawler code and the resources. Your project root needs 4 files: getImage.py, dir.txt, image.txt and test_image.txt. Let me introduce them one by one.
First is getImage.py, the crawler, which saves the images to disk. It reads the three txt files: dir.txt holds the class names, and the crawler creates one folder per class based on it; image.txt lists the training-set entries, each line giving the URL, class name and number of images to download in the format "url-classname-count"; test_image.txt has the same format as image.txt, except that it lists the test-set images.
The contents of each file are listed below. You only need to set up the environment and run the Python script from the project root.

getImage.py

import requests
import os
import time
# Read dir.txt and create one folder per class
def makedir(prefix=None):
    with open("dir.txt","r",encoding="utf-8") as f:
        dirs = f.readlines()
        if prefix is not None:
            filepath = prefix
        else:
            filepath = ""
        for i in range(len(dirs)):
            dir = dirs[i].split("\n")[0]
            if not os.path.exists(filepath+dir):
                os.makedirs(filepath+dir)
                print(filepath+dir+"  folder created")
            else:
                print(filepath + dir + "  folder already exists, skipping")




# Download images given a page URL and a target folder
# mod controls how file names are built; imageNum is the number of images to download
def saveImage(fileurl, filepath, mod="tujigu", imageNum=1):
    # Make sure the target folder exists
    if not os.path.exists(filepath):
        print(filepath+" folder does not exist")
        return
    filename = ""
    number = 1
    for i in range(imageNum):
        if (mod=="tujigu"):
            filename = fileurl.split("/")[-2]+"_"+str(number) + ".jpg"
            filename = filepath+filename
            # Skip files that have already been downloaded
            if os.path.exists(filename):
                print(filename+" already exists, skipping download")
                number += 1
                continue
            # Download the image
            imageurl = fileurl.split("1.jpg")[0] + str(number) + ".jpg"
            image = requests.get(imageurl)
            if(image.status_code!=200):
                print(imageurl+" connection error, status code: "+str(image.status_code))
                return
        with open(filename, "wb") as f:
            f.write(image.content)
            print(filename+" downloaded")
        number += 1
        time.sleep(0.1)


# Read image.txt (or test_image.txt) and download each image into its class folder
def download(prefix=None, mod="train"):
    if(mod=="train"):
        filename = "image.txt"
    elif(mod=="test"):
        filename = "test_image.txt"
    with open(filename, "r", encoding="utf-8") as f:
        urls = f.readlines()
        if(prefix!=None):
            filepath = prefix
        else:
            filepath = ""
        for i in range(len(urls)):
            print("获取信息:"+urls[i])
            url, classname, imageNum = urls[i].split("\n")[0].split("-")
            saveImage(url, filepath+classname+"/", imageNum=int(imageNum))

def getstart(prefix=None):
    if prefix is None:
        prefix = ""
    # Download the training set first
    print("Downloading the training set")
    makedir(prefix+"train/")
    download(prefix+"train/")
    # Then download the test set
    print("Downloading the test set")
    makedir(prefix+"test/")
    download(prefix+"test/", mod="test")


if __name__ == "__main__":
    getstart("data/")

dir.txt

OL
女仆
校服
旗袍
其他

image.txt

https://lns.hywly.com/a/1/35593/1.jpg-OL-20
https://lns.hywly.com/a/1/35569/1.jpg-OL-20
https://lns.hywly.com/a/1/34341/1.jpg-OL-10
https://lns.hywly.com/a/1/34333/1.jpg-OL-10
https://lns.hywly.com/a/1/34498/1.jpg-OL-10
https://lns.hywly.com/a/1/34479/1.jpg-OL-10
https://lns.hywly.com/a/1/34335/1.jpg-OL-10
https://lns.hywly.com/a/1/34310/1.jpg-OL-10
https://lns.hywly.com/a/1/35927/1.jpg-女仆-10
https://lns.hywly.com/a/1/35351/1.jpg-女仆-20
https://lns.hywly.com/a/1/35419/1.jpg-女仆-20
https://lns.hywly.com/a/1/35876/1.jpg-女仆-20
https://lns.hywly.com/a/1/35874/1.jpg-女仆-20
https://lns.hywly.com/a/1/35826/1.jpg-女仆-20
https://lns.hywly.com/a/1/35838/1.jpg-校服-10
https://lns.hywly.com/a/1/35819/1.jpg-校服-10
https://lns.hywly.com/a/1/35620/1.jpg-校服-10
https://lns.hywly.com/a/1/35035/1.jpg-校服-10
https://lns.hywly.com/a/1/34118/1.jpg-校服-10
https://lns.hywly.com/a/1/34065/1.jpg-校服-10
https://lns.hywly.com/a/1/33300/1.jpg-校服-10
https://lns.hywly.com/a/1/33286/1.jpg-校服-10
https://lns.hywly.com/a/1/33013/1.jpg-校服-10
https://lns.hywly.com/a/1/32954/1.jpg-校服-10
https://lns.hywly.com/a/1/35133/1.jpg-校服-10
https://lns.hywly.com/a/1/35122/1.jpg-校服-10
https://lns.hywly.com/a/1/35155/1.jpg-校服-10
https://lns.hywly.com/a/1/35114/1.jpg-校服-10
https://lns.hywly.com/a/1/35026/1.jpg-校服-10
https://lns.hywly.com/a/1/35036/1.jpg-校服-10
https://lns.hywly.com/a/1/34435/1.jpg-校服-10
https://lns.hywly.com/a/1/34109/1.jpg-校服-10
https://lns.hywly.com/a/1/10064/1.jpg-旗袍-10
https://lns.hywly.com/a/1/35939/1.jpg-旗袍-10
https://lns.hywly.com/a/1/35059/1.jpg-旗袍-10
https://lns.hywly.com/a/1/34987/1.jpg-旗袍-10
https://lns.hywly.com/a/1/34984/1.jpg-旗袍-10
https://lns.hywly.com/a/1/34837/1.jpg-旗袍-10
https://lns.hywly.com/a/1/34650/1.jpg-旗袍-10
https://lns.hywly.com/a/1/34669/1.jpg-旗袍-10
https://lns.hywly.com/a/1/34394/1.jpg-旗袍-10
https://lns.hywly.com/a/1/34761/1.jpg-旗袍-10
https://lns.hywly.com/a/1/34396/1.jpg-旗袍-10
https://lns.hywly.com/a/1/34292/1.jpg-旗袍-10
https://lns.hywly.com/a/1/34317/1.jpg-旗袍-10
https://lns.hywly.com/a/1/34166/1.jpg-旗袍-10
https://lns.hywly.com/a/1/33865/1.jpg-旗袍-10
https://lns.hywly.com/a/1/33727/1.jpg-旗袍-10
https://lns.hywly.com/a/1/33513/1.jpg-旗袍-10
https://lns.hywly.com/a/1/33456/1.jpg-旗袍-10
https://lns.hywly.com/a/1/33577/1.jpg-旗袍-10
https://lns.hywly.com/a/1/33696/1.jpg-旗袍-10
https://lns.hywly.com/a/1/33845/1.jpg-旗袍-10

test_image.txt

https://lns.hywly.com/a/1/22791/1.jpg-校服-10
https://lns.hywly.com/a/1/22741/1.jpg-校服-10
https://lns.hywly.com/a/1/22724/1.jpg-校服-10
https://lns.hywly.com/a/1/22040/1.jpg-校服-10
https://lns.hywly.com/a/1/22013/1.jpg-校服-10
https://lns.hywly.com/a/1/7429/1.jpg-旗袍-10
https://lns.hywly.com/a/1/7295/1.jpg-旗袍-10
https://lns.hywly.com/a/1/5579/1.jpg-旗袍-10
https://lns.hywly.com/a/1/5154/1.jpg-旗袍-10
https://lns.hywly.com/a/1/2679/1.jpg-旗袍-10
https://lns.hywly.com/a/1/35760/1.jpg-女仆-20
https://lns.hywly.com/a/1/35746/1.jpg-女仆-20
https://lns.hywly.com/a/1/35773/1.jpg-女仆-10
https://lns.hywly.com/a/1/33467/1.jpg-OL-10
https://lns.hywly.com/a/1/33459/1.jpg-OL-10
https://lns.hywly.com/a/1/33482/1.jpg-OL-10
https://lns.hywly.com/a/1/33480/1.jpg-OL-10
https://lns.hywly.com/a/1/33579/1.jpg-OL-10

The script creates a data folder in the project root containing a train folder and a test folder, which hold the training and test data respectively. Each of them contains one sub-folder per class, and those sub-folders hold the images of each class.
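If you want to double-check what the crawler downloaded, here is a small sketch (my own addition, not part of the original scripts) that walks the data folder created by getImage.py and prints how many images each class folder contains:

import os

# Optional sanity check: print the number of downloaded images per class.
# Assumes the data/train and data/test layout created by getImage.py.
def count_images(root="data"):
    for split in ("train", "test"):
        split_dir = os.path.join(root, split)
        for classname in sorted(os.listdir(split_dir)):
            class_dir = os.path.join(split_dir, classname)
            if os.path.isdir(class_dir):
                print(split, classname, len(os.listdir(class_dir)), "images")

if __name__ == "__main__":
    count_images()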

Step 2: Network Structure

The network part uses ResNet directly; the code is adapted from a post on CSDN.
network.py

import torch
from torch import nn
import torch.nn.functional as f




# ResNet
# Residual block used by ResNet18/34: two 3x3 convolutions
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        # The shortcut output must match the block output in spatial size and depth
        # If not, a 1x1 conv + BN projects it to the same dimensions
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = f.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = f.relu(out)
        return out


# Residual block used by ResNet50/101/152: 1x1 + 3x3 + 1x1 convolutions
class Bottleneck(nn.Module):
    # The first 1x1 and the 3x3 conv have the same number of filters; the last 1x1 has expansion times as many
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = f.relu(self.bn1(self.conv1(x)))
        out = f.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = f.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, linearNum, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(linearNum * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = f.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = f.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18(num_class, linearNum):
    return ResNet(BasicBlock, [2, 2, 2, 2], linearNum=linearNum, num_classes=num_class)


def ResNet34(num_class, linearNum):
    return ResNet(BasicBlock, [3, 4, 6, 3], linearNum=linearNum, num_classes=num_class)


def ResNet50(num_class, linearNum):
    return ResNet(Bottleneck, [3, 4, 6, 3], linearNum=linearNum, num_classes=num_class)


def ResNet101(num_class, linearNum):
    return ResNet(Bottleneck, [3, 4, 23, 3], linearNum=linearNum, num_classes=num_class)


def ResNet152(num_class, linearNum):
    return ResNet(Bottleneck, [3, 8, 36, 3], linearNum=linearNum, num_classes=num_class)


def test(num_class, linearNum):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = ResNet18(num_class, linearNum).to(device)
    y = net(torch.randn(1, 3, 300, 200).to(device))
    print(y.size())

if __name__ == "__main__":
    test(5, 27648)

One thing to note: the network ends with a fully connected layer, and the number of input features of that layer depends on the input image size, so you have to set it yourself for your particular input. A simple way is to feed a tensor of your input size through the network; if it raises a shape-mismatch error, take the required number from the error message and fill it in.
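If you prefer not to read the number out of an error message, you can also compute it directly. The sketch below is my own addition and assumes the network.py above: it builds the net with a placeholder value, runs everything except the final linear layer on a dummy input of your size, and prints the flattened feature count.

import torch
import torch.nn.functional as f
import network

# Compute the input size of the final linear layer for a given image size.
# linearNum=1 is just a placeholder; the linear layer is never used here.
def find_linear_num(height=300, width=200):
    net = network.ResNet18(num_class=5, linearNum=1)
    x = torch.randn(1, 3, height, width)
    with torch.no_grad():
        out = f.relu(net.bn1(net.conv1(x)))
        out = net.layer4(net.layer3(net.layer2(net.layer1(out))))
        out = f.avg_pool2d(out, 4)
    return out.view(1, -1).size(1)

if __name__ == "__main__":
    print(find_linear_num(300, 200))  # prints 27648 for a 300x200 input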

Step 3: Data Loading

Next comes the data loading: we load the training images with their labels and the test images with their labels.
dataload.py

import torch
import torchvision.transforms as tfs
from torch.utils.data import DataLoader
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt




# Hyperparameters
BATCH_SIZE = 1
EPOCHES = 40
LR = 0.01
WIDTH = 200
HEIGHT = 300

# Class-name-to-label mapping
classes = {
    "OL":0,
    "女仆":1,
    "校服":2,
    "旗袍":3,
    "其他":4
}

class_name = ["OL", "女仆", "校服", "旗袍", "其他"]

# Collect image paths and labels for the requested split
def read_image(mode="train"):
    data_root = "./data/train/"
    data_root_test = "./data/test/"
    image_train = []
    image_test = []
    label_train = []
    label_test = []
    with open("dir.txt", "r", encoding="utf-8") as f:
        dirs = f.readlines()
    for i in range(len(dirs)):
        filepath_train = data_root + dirs[i].split("\n")[0]
        filepath_test = data_root_test + dirs[i].split("\n")[0]
        image_name = os.listdir(filepath_train)
        for j in range(len(image_name)):
            image_train.append(filepath_train+"/"+image_name[j])
            label_train.append(classes[dirs[i].split("\n")[0]])
        image_name = os.listdir(filepath_test)
        for j in range(len(image_name)):
            image_test.append(filepath_test+"/"+image_name[j])
            label_test.append(classes[dirs[i].split("\n")[0]])
    if(mode=="train"):
        return image_train, label_train
    elif(mode=="test"):
        return image_test, label_test


def image_transforms(data, height, width):
    data = data.resize((width, height))
    # Convert to a tensor and normalize with ImageNet statistics
    im_tfs = tfs.Compose([
        tfs.ToTensor(),
        tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    data = im_tfs(data)
    return data

# Dataset wrapper class
class dataset(torch.utils.data.Dataset):

    # Constructor
    def __init__(self,mode="train",  width=200, height=300, transforms=image_transforms):
        self.height = height
        self.width = width
        self.transforms = transforms
        self.mode = mode
        image_list, label_list = read_image(mode=self.mode)
        self.image_list = image_list
        self.label_list = label_list


    # __getitem__ lets the DataLoader index the dataset
    def __getitem__(self, idx):
        img = Image.open(self.image_list[idx]).convert("RGB")  # force 3 channels
        img = self.transforms(img, self.height, self.width)
        label = np.array(self.label_list[idx])
        return img, label

    def __len__(self):
        return len(self.image_list)

# Build the train and test DataLoaders
def loader(width=WIDTH, height=HEIGHT):
    train_data = DataLoader(dataset(mode="train", width=width, height=height), batch_size=BATCH_SIZE, shuffle=True)
    test_data = DataLoader(dataset(mode="test", width=width, height=height), batch_size=BATCH_SIZE)
    return train_data, test_data


if __name__ == "__main__":
    train_data, test_data = loader()
    for img, label in train_data:
        print(img.shape)
        print(label.shape)
        break

    for img, label in test_data:
        print(img.shape)
        print(label.shape)
        break


Some of the network's hyperparameters are also defined in this script, for example the batch size. Because my GPU memory is small, I set the batch size to 1; if you have more memory you can use a larger value.

Step 4: Training the Network

This part is the training itself. We use the cross-entropy loss and the Adam optimizer. Note that the loss function expects its target labels as long integers, so the labels have to be converted before the loss is computed.
train.py

import torch
from torch import nn
import torch.nn.functional as f
import torchvision
import torchvision.transforms as tfs
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.models as models
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime

import network
import dataload

"""
certion = nn.CrossEntropyLoss()

a = np.array([1,0])
a = torch.from_numpy(a).long()
b = torch.rand(2, 3)
print(b)
loss = certion(b, a)
print(loss)
"""



# Return the network, the loss function, and the optimizer
def init(mode=1):
    if(mode==1):
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")
    # Cross-entropy loss
    criterion = nn.CrossEntropyLoss()
    net = network.ResNet18(5, 27648).to(device)
    basic_optim = torch.optim.Adam(net.parameters(), lr=dataload.LR, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    return net, criterion, basic_optim



def train(net, criterion, basic_optim, mode=1):
    if (mode == 1):
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")

    train_loss = []
    test_loss = []
    train_acc = []
    test_acc = []



    train_data, test_data = dataload.loader()

    for epoch in range(dataload.EPOCHES):
        _train_num_acc = 0
        _test_num_acc = 0

        _train_loss = 0
        _test_loss = 0

        prev_time = datetime.now()
        # Training
        for img, lab in train_data:
            img = img.to(device)
            lab = lab.long().to(device)

            output = net(img)
            loss = criterion(output, lab)
            basic_optim.zero_grad()
            loss.backward()
            basic_optim.step()

            _train_loss += loss.item()

            _, pred = output.max(1)
            _train_num_acc += (pred==lab).sum().item()


        # Evaluation on the test set (no gradients needed)
        with torch.no_grad():
            for img, lab in test_data:
                img = img.to(device)
                lab = lab.long().to(device)

                output = net(img)
                loss = criterion(output, lab)

                _test_loss += loss.item()

                _, pred = output.max(1)
                _test_num_acc += (pred==lab).sum().item()


        train_loss.append(_train_loss/len(train_data))
        test_loss.append(_test_loss/len(test_data))
        # Divide by the number of samples (not batches) so the accuracy
        # stays correct when BATCH_SIZE > 1
        train_acc.append(_train_num_acc/len(train_data.dataset))
        test_acc.append(_test_num_acc/len(test_data.dataset))

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)

        print("epoch "+str(epoch)+": train loss:"+str(train_loss[-1])+" train acc:"+str(train_acc[-1])+\
              " test loss:"+ str(test_loss[-1]) + " test acc:"+ str(test_acc[-1])+\
              " use time:" + str(h)+":" + str(m) + ":"+ str(s))

    return train_loss, train_acc, test_loss, test_acc


# Plot the loss and accuracy curves
def draw(train_loss, train_acc, test_loss, test_acc):
    x = []
    for i in range(dataload.EPOCHES):
        x.append(i+1)
    plt.plot(x, train_loss, label="train loss")
    plt.plot(x, test_loss, label="test loss")
    plt.grid()
    plt.legend()
    plt.title("loss")
    if not os.path.exists("./img"):
        os.mkdir("./img")
    plt.savefig("./img/loss.jpg")
    # plt.show()

    plt.clf()
    plt.plot(x, train_acc, label="train acc")
    plt.plot(x, test_acc, label="test_acc")
    plt.grid()
    plt.legend()
    plt.title("acc")
    plt.savefig("./img/acc.jpg")
    # plt.show()

def savemodel(net, path="./model/resnet34-1.pth"):
    if not os.path.exists("./model"):
        os.mkdir("./model")
    torch.save(net.state_dict(), path)

if __name__ == "__main__":
    net, criterion, optim = init()
    train_loss, train_acc, test_loss, test_acc = train(net, criterion, optim)
    draw(train_loss, train_acc, test_loss, test_acc)
    savemodel(net)

On my GPU one epoch takes about 5 minutes. In my experiments the network converged at around 25 epochs, with the accuracy on both the training and the test data reaching 100%. The accuracy and loss curves are shown below.
(images: loss and accuracy curves)

Step 5: Testing the Network

I took four images from Baidu for testing; the result is the image at the beginning of this article.
classifer.py

import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
# Allow matplotlib to display the Chinese class names
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False

import network
import dataload


num_class = 5
linearNum = 27648


if __name__ == "__main__":
    net = network.ResNet18(num_class, linearNum)

    path = "./model/resnet34-1.pth"
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    net.load_state_dict(torch.load(path, map_location=torch.device("cpu")))
    net.eval()  # use the stored BatchNorm statistics for inference

    img = ["a.jpg", "b.jpg", "c.jpg", "d.jpg"]
    for i in range(len(img)):
        im = Image.open(img[i]).convert("RGB")  # force 3 channels
        im1 = dataload.image_transforms(im, dataload.HEIGHT, dataload.WIDTH)
        out = net(im1.unsqueeze(0))
        _, pred = out.max(1)
        plt.subplot(2,2, i+1), plt.imshow(im.resize((200,300))), plt.title(dataload.class_name[pred.numpy()[0]]),plt.axis("off")

    plt.savefig("./img/pred.jpg")
    plt.show()

Next, let's visualize the network's feature extraction. We feed the image below into the network:
(image: input sample)
The following code extracts the feature maps of the five stages and visualizes them.
vision.py

import torch
from torch import nn
import torch.nn.functional as f

import matplotlib.pyplot as plt
import network
import dataload
from PIL import Image




if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    im = Image.open("./data/train/OL/34333_9.jpg")
    im = dataload.image_transforms(im, dataload.HEIGHT, dataload.WIDTH)
    im = im.to(device)

    net = network.ResNet18(5, 27648).to(device)
    # Load the trained weights so the feature maps come from the trained classifier
    net.load_state_dict(torch.load("./model/resnet34-1.pth", map_location=device))
    net.eval()


    stage1 = nn.Sequential(*list(net.children())[:-5])
    stage2 = nn.Sequential(*list(net.children())[:-4])
    stage3 = nn.Sequential(*list(net.children())[:-3])
    stage4 = nn.Sequential(*list(net.children())[:-2])
    stage5 = nn.Sequential(*list(net.children())[:-1])

    out1 = stage1(im.unsqueeze(0))
    out2 = stage2(im.unsqueeze(0))
    out3 = stage3(im.unsqueeze(0))
    out4 = stage4(im.unsqueeze(0))
    out5 = stage5(im.unsqueeze(0))


    output = []
    output.append(out1)
    output.append(out2)
    output.append(out3)
    output.append(out4)
    output.append(out5)

    print("out1:" + str(out1.shape))
    print("out2:" + str(out2.shape))
    print("out3:" + str(out3.shape))
    print("out4:" + str(out4.shape))
    print("out5:" + str(out5.shape))


    for i in range(5):
        for j in range(9):
            plt.subplot(3,3,j+1)
            plt.imshow(output[i].squeeze(0)[j].cpu().detach().numpy())
            plt.axis("off")
        plt.savefig("./img/stage"+str(i+1)+".jpg")
        plt.clf()

The resulting feature maps are saved as stage1.jpg through stage5.jpg.
(images: stage1.jpg to stage5.jpg feature maps)

Feeding a second image into the network (image) gives the following feature maps, again saved as stage1.jpg through stage5.jpg.
(images: stage1.jpg to stage5.jpg feature maps)
That's all for this article.
