图片分类任务

    前言:这是一个关于图像分类的文章,数据集连接是一个Cervical Cancer largest dataset (SipakMed)。更换数据集也可以完成其他分类任务。全部资源连接。部分代码如下:

一、数据集处理

1、数据集增强

    感觉这个代码比较垃圾,没有进行转换为tensor而且没有进行归一化。所以没有注释。

"""
这个代码定义一个DataAugmentation类,其中包含各种图像增强的方法。

这个类的初始化函数接受两个参数:
root_dir和output_dir,分别表示原始图像所在目录和增强后图像保存的目录。

augmentation_of_image方法接受两个参数:
test_image和output_path,分别表示要增强的图像路径和增强后的图像保存目录。

在这个方法中定义了各种数据集增强操作,包括图像旋转、缩放、平移、剪切......
并使用imgaug库中的增强器实现这些操作,最后将生成的增强图像保存到指定的目录中
"""

import cv2
import imgaug
import imageio
import os
import numpy as np
from imgaug import augmenters as iaa


class DataAugmentation:

	def __init__(self, root_dir="",output_dir=""):
		self.root_dir = root_dir
		self.output_dir = output_dir
		print("Instance of the DataAugmentation class created")

	def augmentation_of_image(self, test_image, output_path):
		self.test_image = test_image;
		self.output_path = output_path;
		#define the Augmenters


		#properties: A range of values signifies that one of these numbers is randmoly chosen for every augmentation for every batch

		# Apply affine transformations to each image.
		rotate = iaa.Affine(rotate=(-90, 90));
		scale = iaa.Affine(scale={"x": (0.5, 0.9), "y": (0.5,0.9)}); 
		translation = iaa.Affine(translate_percent={"x": (-0.15, 0.15), "y": (-0.15, 0.15)});
		shear = iaa.Affine(shear=(-2, 2));
		zoom = iaa.PerspectiveTransform(scale=(0.01, 0.15), keep_size=True)
		h_flip = iaa.Fliplr(1.0);
		v_flip = iaa.Flipud(1.0);
		padding = iaa.KeepSizeByResize(iaa.CropAndPad(percent=(0.05, 0.25)))


		#More augmentations
		blur = iaa.GaussianBlur(sigma=(0, 1.22))
		contrast = iaa.contrast.LinearContrast((0.75, 1.5));
		contrast_channels = iaa.LinearContrast((0.75, 1.5), per_channel=True)
		sharpen = iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5));
		gauss_noise = iaa.AdditiveGaussianNoise(scale=0.111*255, per_channel=True)
		laplace_noise = iaa.AdditiveLaplaceNoise(scale=(0, 0.111*255))


		#Brightness 
		brightness = iaa.Multiply((0.35, 1.65)) #change brightness between 35% or 165% of the original image
		brightness_channels = iaa.Multiply((0.5, 1.5), per_channel=0.75) # change birghtness for 25% of images.For the remaining 75%, change it, but also channel-wise.

		#CHANNELS (RGB)=(Red,Green,Blue)
		red =iaa.WithChannels(0, iaa.Add((10, 100))) #increase each Red-pixels value within the range 10-100
		red_rot = iaa.WithChannels(0, iaa.Affine(rotate=(0, 45))) #rotate each image's red channel by 0-45 degrees
		green= iaa.WithChannels(1, iaa.Add((10, 100)))
		green_rot=iaa.WithChannels(1, iaa.Affine(rotate=(0, 45)))
		blue=iaa.WithChannels(2, iaa.Add((10, 100)))#increase each Blue-pixels value within the range 10-100
		blue_rot=iaa.WithChannels(2, iaa.Affine(rotate=(0, 45))) #rotate each image's blue channel by 0-45 degrees

		#colors
		channel_shuffle =iaa.ChannelShuffle(1.0);
		grayscale = iaa.Grayscale(1.0)
		hue_n_saturation = iaa.MultiplyHueAndSaturation((0.5, 1.5), per_channel=True) #change hue and saturation with this range of values for different values 
		add_hue_saturation = iaa.AddToHueAndSaturation((-50, 50), per_channel=True) #add more hue and saturation to its pixels
		#Quantize colors using k-Means clustering
		kmeans_color = iaa.KMeansColorQuantization(n_colors=(4, 16)) #quantizes to k means 4 to 16 colors (randomly chosen). Quantizes colors up to 16 colors

		#Alpha Blending 
		blend =iaa.BlendAlphaElementwise((0, 1.0), iaa.Grayscale((0,1.0))) ; #blend depending on which value is greater

		#Contrast augmentors
		clahe = iaa.CLAHE(tile_grid_size_px=((3, 21),[0,2,3,4,5,6,7])) #create a clahe contrast augmentor H=(3,21) and W=(0,7)
		histogram = iaa.HistogramEqualization() #performs histogram equalization

		#Augmentation list of metadata augmentors
		OneofRed = iaa.OneOf([red]);
		OneofGreen = iaa.OneOf([green]);
		OneofBlue = iaa.OneOf([blue]);
		contrast_n_shit = iaa.OneOf([contrast, brightness, brightness_channels]);
		SomeAug = iaa.SomeOf(2, [rotate, scale, translation, shear, h_flip, v_flip], random_order=True);
		SomeClahe = iaa.SomeOf(2, [clahe, iaa.CLAHE(clip_limit=(1, 10)), iaa.CLAHE(tile_grid_size_px=(3, 21)), iaa.GammaContrast((0.5, 2.0)), iaa.AllChannelsCLAHE() , iaa.AllChannelsCLAHE(clip_limit=(1, 10), per_channel=True)],random_order=True) #Random selection from clahe augmentors
		edgedetection= iaa.OneOf([iaa.EdgeDetect(alpha=(0, 0.7)), iaa.DirectedEdgeDetect(alpha=(0, 0.7), direction=(0.0, 1.0))]);# Search in some images either for all edges or for directed edges.These edges are then marked in a black and white image and overlayed with the original image using an alpha of 0 to 0.7.
		canny_filter = iaa.OneOf([iaa.Canny(), iaa.Canny(alpha=(0.5, 1.0), sobel_kernel_size=[3, 7])]); #choose one of the 2 canny filter options
		OneofNoise = iaa.OneOf([blur, gauss_noise, laplace_noise])
		Color_1 = iaa.OneOf([channel_shuffle, grayscale, hue_n_saturation, add_hue_saturation, kmeans_color]);
		Color_2 = iaa.OneOf([channel_shuffle, grayscale, hue_n_saturation, add_hue_saturation, kmeans_color]);
		Flip = iaa.OneOf([histogram, v_flip, h_flip]);

		#Define the augmentors used in the DA
		Augmentors= [SomeAug, SomeClahe, SomeClahe, edgedetection,sharpen, canny_filter, OneofRed, OneofGreen, OneofBlue, OneofNoise, Color_1, Color_2, Flip, contrast_n_shit]

        # ------------------------------------------------------------------------------------------------ #
		# 这段代码的作用是使用一组数据增强器(`Augmentors`)对给定的测试图像进行数据增强,并将生成的增强图像保存到指定的输出路径中。
		# 具体来说,它的步骤如下:
		# 1. 使用cv2.imread函数读取给定的测试图像,将其存储在img变量中。
		# 2. 使用np.array函数创建一个大小为14的数组images,其中包含14个元素,每个元素都是img副本。
		# 3. 使用Augmentors[i].augment_images函数将images数组中的所有图像应用第i个数据增强器,生成一组新的增强图像,并将其存储在images_aug变量中。
		# 4. 使用cv2.imwrite函数将images_aug数组中的第i个增强图像保存到指定的输出路径中,并将其命名为原测试图像文件名加上new和i的字符串,以便区分不同的增强图像。
		# 请注意,这段代码中的test_image和output_path变量需要根据实际情况进行设置,以确保读取正确的测试图像并将增强图像保存到正确的输出路径中。
		# 此外,`Augmentors`数组中的每个数据增强器需要在代码的其他地方进行定义和初始化。
		# ------------------------------------------------------------------------------------------------ #
		for i in range(0,14):
			img = cv2.imread(test_image) #read you image
			images = np.array([img for _ in range(14)], dtype=np.uint8)  # 12 is the size of the array that will hold 8 different images
			images_aug = Augmentors[i].augment_images(images)  #alternate between the different augmentors for a test image
			cv2.imwrite(os.path.join(output_path, test_image +"new"+str(i)+'.jpg'), images_aug[i])  #write all changed images


		#implementation - save new DA image
		# imglist = []
		# image = cv2.imread("test.jpg");
		# imglist.append(image);
		# image_aug = SomeClahe.augment_images(imglist);
		# cv2.imwrite("path.join(output_path, augmented_image.jpg")), image_aug[0]);
"""
这段代码定义一个DataAugmention_Extension的类,该类用于数据集。

DataAugmention是一个自定义的类,用于实现数据增强功能。
"""
import os
from sklearn.model_selection import train_test_split
import numpy as np
import glob
import shutil
from DataAugmentation import DataAugmentation #import library from the python file



class DataAugmentation_Extension:

	#root_dir = "/home/cantonioupao/Desktop/SIPakMed"
	#output_dir = "home/cantonioupao/Desktop/SIPakMed/Divided_Dataset"
	def __init__(self, directory=""):
		self.directory = directory
		print("Instance of DataAugmentation_Extension class created")

	def printnow(self, dir):
		print("Just testing that the method calling is working "+ dir)


	def extend_dataset(self,directory):
		#Create an instance of class 
		print("HEY")
		library_augment = DataAugmentation();
		self.directory = directory 
		if not os.path.exists(self.directory):
			print("ERROR! Couldn't find directory!")
		else:
			print("Directory exists")
		for file in os.listdir(directory):            #for any file inside the root directory 
			classes_path = os.path.join(directory, file)  #So for every folder class we create a class directory
			class_files = [name for name in glob.glob(os.path.join(classes_path,'*.bmp'))]  #alternatively we can use the globe as mentioned
			print(class_files); #call augmentation for all class_files
			for i in range(len(class_files)):
				library_augment.augmentation_of_image(class_files[i], classes_path)

    调用上面的代码完成数据集增强:

from DataAugmentation_Extension import DataAugmentation_Extension

# -------------------------------------------------------------- #
# target_directory是用来需要增强的文件夹
# -------------------------------------------------------------- #
target_directory = "E:\code\Cervical_Cancer\concat\data_origin"

datasetda = DataAugmentation_Extension()
datasetda.extend_dataset(target_directory)

2、数据集划分

import os
from sklearn.model_selection import train_test_split
import numpy as np
import glob
import shutil



class DatasetDivision:

	#root_dir = "/home/cantonioupao/Desktop/SIPakMed"
	#output_dir = "home/cantonioupao/Desktop/SIPakMed/Divided_Dataset"
	# -------------------------------------------------------------- #
	# 设置实例的root_dir和output_dir属性。如果没有提供这些属性的值,
	# 则将他们设置为空字符串。
	# 最后他会打印消息表示已经创建了类的实例。
	# -------------------------------------------------------------- #
	def __init__(self, root_dir="",output_dir=""):
		self.root_dir = root_dir
		self.output_dir = output_dir
		print("Instance of the class created")

	# 测试方法调用是否正常工作,接受一个new_dir并打印到控制台上
	def printnow(self, new_dir):
		print("Just testing that the method calling is working"+new_dir)

	# -------------------------------------------------------------- #
	# 首先检测输出目录是否存在,如果存在,则值创建相应的训练集、验证集和测试集目录。
	# 如果输出目录不存在,则创建输出目录和相应的训练集、验证集和测试集目录。
	# -------------------------------------------------------------- #
	def divide_dataset(self, root_dir,output_dir):
		self.root_dir =root_dir
		self.output_dir = output_dir
		if os.path.exists(self.output_dir):
			if not os.path.exists(os.path.join(self.output_dir,'train')):
				os.mkdir(os.path.join(self.output_dir,'train'))  #create the first directory
				os.mkdir(os.path.join(self.output_dir,'val')) # 2nd directory
				os.mkdir(os.path.join(self.output_dir,'test')) #3 directory
		else:
			os.mkdir(self.output_dir)
			os.mkdir(os.path.join(self.output_dir,'train')) #create the first directory
			os.mkdir(os.path.join(self.output_dir, 'val')) # 2nd directory
			os.mkdir(os.path.join(self.output_dir, 'test')) #3 directory
		# Split train/val/test sets
		# -------------------------------------------------------------- #
		# 对数据集根目录进行遍历,并针对每一个文件夹执行以下操作:
		# 1、定义class_path变量,该变量包含当前文件的完整路径
		# 2、定义class_files变量,该变量包含当前文件夹中所有.bmp文件的完整路径。
		# 3、使用train_test_split()函数将class_files随机分为训练集+验证集+测试集
		# 4、再次使用train_test_split()函数将train_and_valid随机分为训练集和验证集
		# 5、定义train_dir、val_dir、和test_dir变量分别对应训练集、验证集、测试集的目录路径
		# 6、如果这些路径都不存在则创建他们。
		# 7、使用shutil.move()函数将训练集、验证集、测试集中的图像文件移动到相应目录中。
		# -------------------------------------------------------------- #
		for file in os.listdir(root_dir):            #for any file inside the root directory 
			classes_path = os.path.join(root_dir, file)  #fSo for every folder class we create a class directory
			class_files = [name for name in glob.glob(os.path.join(classes_path, '*.jpg'))]  #alternatively we can use the globe as mentioned
			train_and_valid, test = train_test_split(class_files, test_size=0.2, random_state=42)  #this signifies that our test dataset will e the 20% of the dataset - sklearn function#
			train, val = train_test_split(train_and_valid, test_size=0.25, random_state=42)  #this signifies that the validation dataset will be 20% of it , leaving 60% for training #

			#Define the training, validation and testing directories that the frame folders will be moved to.
			# 定义训练集、验证集、测试集的目录路径,使用os.path.join()函数将他们与输出目录和当前文件夹名称合并起来。
			train_dir = os.path.join(self.output_dir, 'train',file) #creates the path for Divided_Dataset->train->Dyskeratotic 
			val_dir = os.path.join(self.output_dir, 'val', file) #creates the path for Divided_Dataset->val->Dyskeratotic 
			test_dir = os.path.join(self.output_dir, 'test',file) #creates the path for Divided_Dataset->test->Dyskeratotic 
			# 检查训练集、验证集合测试集是否存在,如果不存在,分别使用os.mkdir函数创建这些目录
			if not os.path.exists(train_dir):
				os.mkdir(train_dir)
			if not os.path.exists(val_dir):
				os.mkdir(val_dir)
			if not os.path.exists(test_dir):
				os.mkdir(test_dir)

			# 使用shulti.move()函数将训练集、验证集和测试集中的图像移动到相应目录中,最后打印信息表示划分完成。
			for frame_folders in train:
				#get only the last directory of the path frame_folders
				frame_folder = os.path.join(root_dir,file,frame_folders)
				shutil.move(frame_folder,train_dir)
			for frame_folders in val:
				frame_folder = os.path.join(root_dir,file,frame_folders)
				shutil.move(frame_folder,val_dir)
			for frame_folders in test:
				frame_folder = os.path.join(root_dir,file,frame_folders)
				shutil.move(frame_folder,test_dir)
			print('Dataset Division finished.')        

    调用上面代码完成划分:

from DatasetDivision import DatasetDivision

# -------------------------------------------------------------- #
# path_dir是经过增强后的数据集文件夹
# out_dir是划分好的数据集文件夹
# -------------------------------------------------------------- #
path_dir= "E:\code\Cervical_Cancer\concat\data_enhancement"
output_dir = "E:\code\Cervical_Cancer\concat\dataset"
#create an instance of the class
datasetdiv1 = DatasetDivision()
datasetdiv1.printnow("The new guy")
datasetdiv1.divide_dataset(path_dir, output_dir)

3、数据集裁剪

    其实数据集裁剪应该在未划分前就该做,没办法下次注意。

"""
因为resnet需要图片大小文224*224,这个程序是实现裁剪。

"""

#提取目录下所有图片,更改尺寸后保存到另一目录
import os
import glob
import cv2
import numpy as np
from PIL import Image

# 将原图像裁剪到224*224像素
def convertjpg(jpgfile, outdir, width=224, height=224):
    img = Image.open(jpgfile)
    try:
        # 将输入图像保存为临时文件'temp.jpg'
        img.save('temp.jpg')
        img_data = cv2.imread('temp.jpg')

        # 删除'temp.jpg'文件夹下的所有文件(没必要这段代码,纯属于害怕数据集太大占内存)
        os.remove('temp.jpg')
        new_img_data = cv2.resize(img_data, (width, height))

        # 使用PIL库的Image.fromarray函数将缩放后的图像数据转换为PIL Image赋值给new_img
        new_img = Image.fromarray(new_img_data)

        # 使用os.path.basename函数获取输入图像文件名,将其与输出文件夹路径拼接,得到输出文件名outfile
        filename = os.path.basename(jpgfile)
        outfile = os.path.join(outdir, filename)

        # 将裁剪后的图像保存到输出文件夹中
        new_img.save(outfile)

    # 这是为了捕捉异常,上面代码出现错误会调转到这部分程序。
    # Exception是一个基类,可以捕捉异常,使用as关键字将Exception类型的异常赋值给变量e。
    except Exception as e:
        print(e)

# 将该文件夹下的所有jpg格式统一进行裁剪
for jpgfile in glob.glob("E:/code/Cervical_Cancer/concat/dataset/val/im_Superficial-Intermediate/*.jpg"):
    # 输出文件夹
    convertjpg(jpgfile, "E:/code/Cervical_Cancer/concat/datasets/val/im_Superficial-Intermediate")

4、重命名

    恩...........这个不重要,主要是看着舒服。而且它也应该在未划分前执行。

# -*- coding:utf8 -*-
import os


class BatchRename():
    """
    批量重命名文件夹中的图片文件
    """

    def __init__(self):
        self.path = "E:\code\Cervical_Cancer\concat\dataset\\val\im_Superficial-Intermediate"  # 表示需要命名处理的文件夹

    def rename(self):
        filelist = os.listdir(self.path)  # 获取文件路径
        total_num = len(filelist)  # 获取文件长度(个数)
        i = 1  # 表示文件的命名是从1开始的
        for item in filelist:
            if item.endswith('.jpg'):  # 转换格式就可以调整为自己需要的格式即可
                src = os.path.join(os.path.abspath(self.path), item)
                # dst = os.path.join(os.path.abspath(self.path), '' + str(i) + '.jpg')
                dst = os.path.join(os.path.abspath(self.path),  format(str(i), '0>3s') + '.jpg')
                # 这种情况下的命名格式为00000.jpg形式,可以自定义格式
                try:
                    os.rename(src, dst)
                    print('converting %s to %s ...' % (src, dst))
                    i = i + 1
                except:
                    continue
        print('total %d to rename & converted %d jpgs' % (total_num, i))


if __name__ == '__main__':
    demo = BatchRename()
    demo.rename()

二、训练

    主要就是这个代码:

"""
新项目记得更换:
            classes_path    = 'model_data/cls_classes.txt'
            backbone        = "resnet50"
            好多都要修改,不写了。

训练时我没有使用val,而是从train中取部分代替,代替后的数据集划分比例为6:2:2
"""
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from nets import get_model_from_name
from utils.callbacks import LossHistory
from utils.dataloader import DataGenerator, detection_collate
from utils.utils import (download_weights, get_classes, get_lr_scheduler,
                         set_optimizer_lr, show_config, weights_init)
from utils.utils_fit import fit_one_epoch

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

if __name__ == "__main__":
    # ---------------------------------------------------- #
    #   是否使用Cuda
    #   没有GPU可以设置成False
    # ---------------------------------------------------- #
    Cuda            = True
    # --------------------------------------------------------------------- #
    #   distributed     用于指定是否使用单机多卡分布式运行
    #                   终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。
    #                   Windows系统下默认使用DP模式调用所有显卡,不支持DDP。
    #   DP模式:
    #       设置            distributed = False
    #       在终端中输入    CUDA_VISIBLE_DEVICES=0,1 python train.py
    #   DDP模式:
    #       设置            distributed = True
    #       在终端中输入    CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
    # --------------------------------------------------------------------- #
    distributed     = False
    # --------------------------------------------------------------------- #
    #   sync_bn     是否使用sync_bn,DDP模式多卡可用
    # --------------------------------------------------------------------- #
    sync_bn         = False
    # --------------------------------------------------------------------- #
    #   fp16        是否使用混合精度训练
    #               可减少约一半的显存、需要pytorch1.7.1以上
    # --------------------------------------------------------------------- #
    fp16            = False
    # ---------------------------------------------------- #
    #   训练自己的数据集的时候一定要注意修改classes_path
    #   修改成自己对应的种类的txt
    # ---------------------------------------------------- #
    classes_path    = 'model_data/cls_classes.txt' 
    # ---------------------------------------------------- #
    #   输入的图片大小
    # ---------------------------------------------------- #
    input_shape     = [224, 224]
    # ------------------------------------------------------ #
    #   所用模型种类:
    #   mobilenetv2、
    #   resnet18、resnet34、resnet50、resnet101、resnet152
    #   vgg11、vgg13、vgg16、vgg11_bn、vgg13_bn、vgg16_bn、
    #   vit_b_16、
    #   swin_transformer_tiny、swin_transformer_small、swin_transformer_base
    # ------------------------------------------------------ #
    backbone        = "resnet50"
    # -------------------------------------------------------------------------------------------------------------- #
    #   是否使用主干网络的预训练权重,此处使用的是主干的权重,因此是在模型构建的时候进行加载的。
    #   如果设置了model_path,则主干的权值无需加载,pretrained的值无意义。
    #   如果不设置model_path,pretrained = True,此时仅加载主干开始训练。
    #   如果不设置model_path,pretrained = False,Freeze_Train = Fasle,此时从0开始训练,且没有冻结主干的过程。
    # --------------------------------------------------------------------------------------------------------- #
    pretrained      = True
    # --------------------------------------------------------------------------------------------------------- #
    #   权值文件的下载请看README,可以通过网盘下载。模型的 预训练权重 对不同数据集是通用的,因为特征是通用的。
    #   模型的 预训练权重 比较重要的部分是 主干特征提取网络的权值部分,用于进行特征提取。
    #   预训练权重对于99%的情况都必须要用,不用的话主干部分的权值太过随机,特征提取效果不明显,网络训练的结果也不会好
    #
    #   如果训练过程中存在中断训练的操作,可以将model_path设置成logs文件夹下的权值文件,将已经训练了一部分的权值再次载入。
    #   同时修改下方的 冻结阶段 或者 解冻阶段 的参数,来保证模型epoch的连续性。
    #   
    #   当model_path = ''的时候不加载整个模型的权值。
    #
    #   此处使用的是整个模型的权重,因此是在train.py进行加载的,pretrain不影响此处的权值加载。
    #   如果想要让模型从主干的预训练权值开始训练,则设置model_path = '',pretrain = True,此时仅加载主干。
    #   如果想要让模型从0开始训练,则设置model_path = '',pretrain = Fasle,此时从0开始训练。
    # --------------------------------------------------------------------------------------------------------- #
    model_path      = ""
    # --------------------------------------------------------------------------------------------------------- #
    #   训练分为两个阶段,分别是冻结阶段和解冻阶段。设置冻结阶段是为了满足机器性能不足的同学的训练需求。
    #   冻结训练需要的显存较小,显卡非常差的情况下,可设置Freeze_Epoch等于UnFreeze_Epoch,此时仅仅进行冻结训练。
    #      
    #   在此提供若干参数设置建议,各位训练者根据自己的需求进行灵活调整:
    #   (一)从整个模型的预训练权重开始训练: 
    #       Adam:
    #           Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 1e-3。(冻结)
    #           Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-3。(不冻结)
    #       SGD:
    #           Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 200,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 1e-2。(冻结)
    #           Init_Epoch = 0,UnFreeze_Epoch = 200,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2。(不冻结)
    #       其中:UnFreeze_Epoch可以在100-300之间调整。
    #   (二)从0开始训练:
    #       Adam:
    #           Init_Epoch = 0,UnFreeze_Epoch = 300,Unfreeze_batch_size >= 16,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-3。(不冻结)
    #       SGD:
    #           Init_Epoch = 0,UnFreeze_Epoch = 300,Unfreeze_batch_size >= 16,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2。(不冻结)
    #       其中:UnFreeze_Epoch尽量不小于300。
    #   (三)batch_size的设置:
    #       在显卡能够接受的范围内,以大为好。显存不足与数据集大小无关,提示显存不足(OOM或者CUDA out of memory)请调小batch_size。
    #       受到BatchNorm层影响,batch_size最小为2,不能为1。
    #       正常情况下Freeze_batch_size建议为Unfreeze_batch_size的1-2倍。不建议设置的差距过大,因为关系到学习率的自动调整。
    # -----------------------------------------------------------------------------------------------------------#
    # ------------------------------------------------------------------#
    #   冻结阶段训练参数
    #   此时模型的主干被冻结了,特征提取网络不发生改变
    #   占用的显存较小,仅对网络进行微调
    #   Init_Epoch          模型当前开始的训练世代,其值可以大于Freeze_Epoch,如设置:
    #                       Init_Epoch = 60、Freeze_Epoch = 50、UnFreeze_Epoch = 100
    #                       会跳过冻结阶段,直接从60代开始,并调整对应的学习率。
    #                       (断点续练时使用)
    #   Freeze_Epoch        模型冻结训练的Freeze_Epoch
    #                       (当Freeze_Train=False时失效)
    #   Freeze_batch_size   模型冻结训练的batch_size
    #                       (当Freeze_Train=False时失效)
    # ------------------------------------------------------------------ #
    Init_Epoch          = 0
    Freeze_Epoch        = 50
    Freeze_batch_size   = 8
    # ------------------------------------------------------------------ #
    #   解冻阶段训练参数
    #   此时模型的主干不被冻结了,特征提取网络会发生改变
    #   占用的显存较大,网络所有的参数都会发生改变
    #   UnFreeze_Epoch          模型总共训练的epoch
    #   Unfreeze_batch_size     模型在解冻后的batch_size
    # ------------------------------------------------------------------ #
    UnFreeze_Epoch      = 5
    Unfreeze_batch_size = 5
    # ------------------------------------------------------------------ #
    #   Freeze_Train    是否进行冻结训练
    #                   默认先冻结主干训练后解冻训练。
    # ------------------------------------------------------------------ #
    Freeze_Train        = True
    
    # ------------------------------------------------------------------ #
    #   其它训练参数:学习率、优化器、学习率下降有关
    # ------------------------------------------------------------------ #
    # ------------------------------------------------------------------ #
    #   Init_lr         模型的最大学习率
    #                   当使用Adam优化器时建议设置  Init_lr=1e-3
    #                   当使用SGD优化器时建议设置   Init_lr=1e-2
    #   Min_lr          模型的最小学习率,默认为最大学习率的0.01
    # ------------------------------------------------------------------ #
    Init_lr             = 1e-2
    Min_lr              = Init_lr * 0.01
    # ------------------------------------------------------------------ #
    #   optimizer_type  使用到的优化器种类,可选的有adam、sgd
    #                   当使用Adam优化器时建议设置  Init_lr=1e-3
    #                   当使用SGD优化器时建议设置   Init_lr=1e-2
    #   momentum        优化器内部使用到的momentum参数
    #   weight_decay    权值衰减,可防止过拟合
    #                   使用adam优化器时会有错误,建议设置为0
    # ------------------------------------------------------------------ #
    optimizer_type      = "sgd"
    momentum            = 0.9
    weight_decay        = 5e-4
    # ------------------------------------------------------------------ #
    #   lr_decay_type   使用到的学习率下降方式,可选的有step、cos
    # ------------------------------------------------------------------ #
    lr_decay_type       = "cos"
    # ------------------------------------------------------------------ #
    #   save_period     多少个epoch保存一次权值
    # ------------------------------------------------------------------ #
    save_period         = 10
    # ------------------------------------------------------------------ #
    #   save_dir        权值与日志文件保存的文件夹
    # ------------------------------------------------------------------ #
    save_dir            = 'logs'
    # ------------------------------------------------------------------ #
    #   num_workers     用于设置是否使用多线程读取数据
    #                   开启后会加快数据读取速度,但是会占用更多内存
    #                   内存较小的电脑可以设置为2或者0  
    # ------------------------------------------------------------------ #
    num_workers         = 2

    # ------------------------------------------------------ #
    #   train_annotation_path   训练图片路径和标签
    #   test_annotation_path    验证图片路径和标签(使用测试集代替验证集)
    # ------------------------------------------------------ #
    train_annotation_path   = "cls_train.txt"
    test_annotation_path    = 'cls_test.txt'

    # ------------------------------------------------------ #
    #   设置用到的显卡
    # ------------------------------------------------------ #
    ngpus_per_node  = torch.cuda.device_count()
    if distributed:
        dist.init_process_group(backend="nccl")
        local_rank  = int(os.environ["LOCAL_RANK"])
        rank        = int(os.environ["RANK"])
        device      = torch.device("cuda", local_rank)
        if local_rank == 0:
            print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...")
            print("Gpu Device Count : ", ngpus_per_node)
    else:
        device          = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        local_rank      = 0
        rank            = 0

    # ---------------------------------------------------- #
    #   下载预训练权重
    # ---------------------------------------------------- #
    if pretrained:
        if distributed:
            if local_rank == 0:
                download_weights(backbone)  
            dist.barrier()
        else:
            download_weights(backbone)

    # ------------------------------------------------------ #
    #   获取classes
    # ------------------------------------------------------ #
    class_names, num_classes = get_classes(classes_path)

    if backbone not in ['vit_b_16', 'swin_transformer_tiny', 'swin_transformer_small', 'swin_transformer_base']:
        model = get_model_from_name[backbone](num_classes=num_classes, pretrained=pretrained)
    else:
        model = get_model_from_name[backbone](input_shape = input_shape, num_classes=num_classes, pretrained=pretrained)

    if not pretrained:
        weights_init(model)
    if model_path != "":
        # ------------------------------------------------------ #
        #   权值文件请看README,百度网盘下载
        # ------------------------------------------------------ #
        if local_rank == 0:
            print('Load weights {}.'.format(model_path))
        
        # ------------------------------------------------------ #
        #   根据预训练权重的Key和模型的Key进行加载
        # ------------------------------------------------------ #
        model_dict      = model.state_dict()
        pretrained_dict = torch.load(model_path, map_location = device)
        load_key, no_load_key, temp_dict = [], [], {}
        for k, v in pretrained_dict.items():
            if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
                temp_dict[k] = v
                load_key.append(k)
            else:
                no_load_key.append(k)
        model_dict.update(temp_dict)
        model.load_state_dict(model_dict)
        # ------------------------------------------------------ #
        #   显示没有匹配上的Key
        # ------------------------------------------------------ #
        if local_rank == 0:
            print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
            print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))
            print("\n\033[1;33;44m温馨提示,head部分没有载入是正常现象,Backbone部分没有载入是错误的。\033[0m")

    # ---------------------- #
    #   记录Loss
    # ---------------------- #
    if local_rank == 0:
        loss_history = LossHistory(save_dir, model, input_shape=input_shape)
    else:
        loss_history = None
        
    # ------------------------------------------------------------------ #
    #   torch 1.2不支持amp,建议使用torch 1.7.1及以上正确使用fp16
    #   因此torch1.2这里显示"could not be resolve"
    # ------------------------------------------------------------------ #
    if fp16:
        from torch.cuda.amp import GradScaler as GradScaler
        scaler = GradScaler()
    else:
        scaler = None

    model_train     = model.train()
    # ---------------------------- #
    #   多卡同步Bn
    # ---------------------------- #
    if sync_bn and ngpus_per_node > 1 and distributed:
        model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train)
    elif sync_bn:
        print("Sync_bn is not support in one gpu or not distributed.")

    if Cuda:
        if distributed:
            # ---------------------------- #
            #   多卡平行运行
            # ---------------------------- #
            model_train = model_train.cuda(local_rank)
            model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=True)
        else:
            model_train = torch.nn.DataParallel(model)
            cudnn.benchmark = True
            model_train = model_train.cuda()
        
    # --------------------------- #
    #   读取数据集对应的txt
    # --------------------------- #
    with open(train_annotation_path, encoding='utf-8') as f:
        train_lines = f.readlines()
    with open(test_annotation_path, encoding='utf-8') as f:
        val_lines   = f.readlines()
    num_train   = len(train_lines)
    num_val     = len(val_lines)
    np.random.seed(10101)
    np.random.shuffle(train_lines)
    np.random.seed(None)
    
    if local_rank == 0:
        show_config(
            num_classes = num_classes, backbone = backbone, model_path = model_path, input_shape = input_shape, \
            Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, \
            Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \
            save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val
        )
    # --------------------------------------------------------- #
    #   总训练世代指的是遍历全部数据的总次数
    #   总训练步长指的是梯度下降的总次数 
    #   每个训练世代包含若干训练步长,每个训练步长进行一次梯度下降。
    #   此处仅建议最低训练世代,上不封顶,计算时只考虑了解冻部分
    # ---------------------------------------------------------- #
    wanted_step = 3e4 if optimizer_type == "sgd" else 1e4
    total_step  = num_train // Unfreeze_batch_size * UnFreeze_Epoch
    if total_step <= wanted_step:
        wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1
        print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m"%(optimizer_type, wanted_step))
        print("\033[1;33;44m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m"%(num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))
        print("\033[1;33;44m[Warning] 由于总训练步长为%d,小于建议总步长%d,建议设置总世代为%d。\033[0m"%(total_step, wanted_step, wanted_epoch))

    # ------------------------------------------------------ #
    #   主干特征提取网络特征通用,冻结训练可以加快训练速度
    #   也可以在训练初期防止权值被破坏。
    #   Init_Epoch为起始世代
    #   Freeze_Epoch为冻结训练的世代
    #   UnFreeze_Epoch总训练世代
    #   提示OOM或者显存不足请调小Batch_size
    # ------------------------------------------------------ #
    if True:
        UnFreeze_flag = False
        # ------------------------------------ #
        #   冻结一定部分训练
        # ------------------------------------ #
        if Freeze_Train:
            model.freeze_backbone()

        # ------------------------------------------------------------------- #
        #   如果不冻结训练的话,直接设置batch_size为Unfreeze_batch_size
        # ------------------------------------------------------------------- #
        batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size

        # ------------------------------------------------------------------- #
        #   判断当前batch_size,自适应调整学习率
        # ------------------------------------------------------------------- #
        nbs             = 64
        lr_limit_max    = 1e-3 if optimizer_type == 'adam' else 1e-1
        lr_limit_min    = 1e-4 if optimizer_type == 'adam' else 5e-4
        if backbone in ['vit_b_16', 'swin_transformer_tiny', 'swin_transformer_small', 'swin_transformer_base']:
            nbs             = 256
            lr_limit_max    = 1e-3 if optimizer_type == 'adam' else 1e-1
            lr_limit_min    = 1e-5 if optimizer_type == 'adam' else 5e-4
        Init_lr_fit     = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
        Min_lr_fit      = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
        
        optimizer = {
            'adam'  : optim.Adam(model_train.parameters(), Init_lr_fit, betas = (momentum, 0.999), weight_decay=weight_decay),
            'sgd'   : optim.SGD(model_train.parameters(), Init_lr_fit, momentum = momentum, nesterov=True)
        }[optimizer_type]
        
        # --------------------------------------- #
        #   获得学习率下降的公式
        # --------------------------------------- #
        lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
        
        # --------------------------------------- #
        #   判断每一个世代的长度
        # --------------------------------------- #
        epoch_step      = num_train // batch_size
        epoch_step_val  = num_val // batch_size
        
        if epoch_step == 0 or epoch_step_val == 0:
            raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")

        train_dataset   = DataGenerator(train_lines, input_shape, True)
        val_dataset     = DataGenerator(val_lines, input_shape, False)
        
        if distributed:
            train_sampler   = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True,)
            val_sampler     = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False,)
            batch_size      = batch_size // ngpus_per_node
            shuffle         = False
        else:
            train_sampler   = None
            val_sampler     = None
            shuffle         = True
            
        gen             = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True, 
                                drop_last=True, collate_fn=detection_collate, sampler=train_sampler)
        gen_val         = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                                drop_last=True, collate_fn=detection_collate, sampler=val_sampler)
        # --------------------------------------- #
        #   开始模型训练
        # --------------------------------------- #
        for epoch in range(Init_Epoch, UnFreeze_Epoch):
            # --------------------------------------- #
            #   如果模型有冻结学习部分
            #   则解冻,并设置参数
            # --------------------------------------- #
            if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
                batch_size = Unfreeze_batch_size

                # ------------------------------------------------------------------- #
                #   判断当前batch_size,自适应调整学习率
                # ------------------------------------------------------------------- #
                nbs             = 64
                lr_limit_max    = 1e-3 if optimizer_type == 'adam' else 1e-1
                lr_limit_min    = 1e-4 if optimizer_type == 'adam' else 5e-4
                if backbone in ['vit_b_16', 'swin_transformer_tiny', 'swin_transformer_small', 'swin_transformer_base']:
                    nbs             = 256
                    lr_limit_max    = 1e-3 if optimizer_type == 'adam' else 1e-1
                    lr_limit_min    = 1e-5 if optimizer_type == 'adam' else 5e-4
                Init_lr_fit     = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
                Min_lr_fit      = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
                # --------------------------------------- #
                #   获得学习率下降的公式
                # --------------------------------------- #
                lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
                
                model.Unfreeze_backbone()

                epoch_step      = num_train // batch_size
                epoch_step_val  = num_val // batch_size

                if epoch_step == 0 or epoch_step_val == 0:
                    raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")

                if distributed:
                    batch_size = batch_size // ngpus_per_node

                gen             = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                                        drop_last=True, collate_fn=detection_collate, sampler=train_sampler)
                gen_val         = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                                        drop_last=True, collate_fn=detection_collate, sampler=val_sampler)

                UnFreeze_flag = True

            if distributed:
                train_sampler.set_epoch(epoch)
                
            set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
            
            fit_one_epoch(model_train, model, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)

        if local_rank == 0:
            loss_history.writer.close()

三、评估

"""
定义一个名为Eval_Classification的类,它继承Classification类。


"""
import os

import numpy as np
import torch

from classification import (Classification, cvtColor, letterbox_image,
                            preprocess_input)
from utils.utils import letterbox_image
from utils.utils_metrics import evaluteTop1_5

# ------------------------------------------------------ #
#   test_annotation_path    测试图片路径和标签
# ------------------------------------------------------ #
test_annotation_path    = 'cls_test.txt'
# ------------------------------------------------------ #
#   metrics_out_path        指标保存的文件夹
# ------------------------------------------------------ #
metrics_out_path        = "metrics_out"

class Eval_Classification(Classification):
    def detect_image(self, image):        
        # --------------------------------------------------------- #
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        # --------------------------------------------------------- #
        image       = cvtColor(image)
        # --------------------------------------------------- #
        #   对图片进行不失真的resize
        # --------------------------------------------------- #
        image_data  = letterbox_image(image, [self.input_shape[1], self.input_shape[0]], self.letterbox_image)
        # --------------------------------------------------------- #
        #   归一化+添加上batch_size维度+转置
        # --------------------------------------------------------- #
        image_data  = np.transpose(np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0), (0, 3, 1, 2))

        with torch.no_grad():
            photo   = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                photo = photo.cuda()
            # --------------------------------------------------- #
            #   图片传入网络进行预测
            # --------------------------------------------------- #
            preds   = torch.softmax(self.model(photo)[0], dim=-1).cpu().numpy()

        return preds

if __name__ == "__main__":
    if not os.path.exists(metrics_out_path):
        os.makedirs(metrics_out_path)
            
    classfication = Eval_Classification()
    
    with open("./cls_test.txt", "r") as f:
        lines = f.readlines()
    top1, top5, Recall, Precision = evaluteTop1_5(classfication, lines, metrics_out_path)
    print("top-1 accuracy = %.2f%%" % (top1*100))
    print("top-5 accuracy = %.2f%%" % (top5*100))
    print("mean Recall = %.2f%%" % (np.mean(Recall)*100))
    print("mean Precision = %.2f%%" % (np.mean(Precision)*100))

    我们来看一下classification代码:

"""
这段代码实现一个图像分类器,用于对输入的图片进行分类。
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn

from nets import get_model_from_name
from utils.utils import (cvtColor, get_classes, letterbox_image,
                         preprocess_input, show_config)


# -------------------------------------------- #
#   使用自己训练好的模型预测需要修改3个参数
#   model_path和classes_path和backbone都需要修改!
# --------------------------------------------#
class Classification(object):
    _defaults = {
        # -------------------------------------------------------------------------- #
        #   使用自己训练好的模型进行预测一定要修改model_path和classes_path!
        #   model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
        #   如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
        # -------------------------------------------------------------------------- #
        "model_path"        : 'logs/best_epoch_weights.pth',
        "classes_path"      : 'model_data/cls_classes.txt',
        # -------------------------------------------------------------------- #
        #   输入的图片大小
        # -------------------------------------------------------------------- #
        "input_shape"       : [224, 224],
        # -------------------------------------------------------------------- #
        #   所用模型种类:
        #   mobilenetv2、
        #   resnet18、resnet34、resnet50、resnet101、resnet152
        #   vgg11、vgg13、vgg16、vgg11_bn、vgg13_bn、vgg16_bn、
        #   vit_b_16、
        #   swin_transformer_tiny、swin_transformer_small、swin_transformer_base
        # -------------------------------------------------------------------- #
        "backbone"          : 'resnet50',
        # -------------------------------------------------------------------- #
        #   该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize
        #   否则对图像进行CenterCrop
        # -------------------------------------------------------------------- #
        "letterbox_image"   : False,
        # ------------------------------- #
        #   是否使用Cuda
        #   没有GPU可以设置成False
        # ------------------------------- #
        "cuda"              : True
    }

    # 定义一个类方法get_defaults,用于获取默认参数值。
    # 它接受一个参数n,表示要获取的参数名,如果该参数名在_defaults字典中,则返回对应的值,否则返回一个错误信息。
    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    # --------------------------------------------------- #
    #   初始化classification
    # --------------------------------------------------- #
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        # --------------------------------------------------- #
        #   获得种类
        # --------------------------------------------------- #
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.generate()

        # 调用show_config函数显示配置信息
        show_config(**self._defaults)

    # --------------------------------------------------- #
    #   获得所有的分类
    # --------------------------------------------------- #
    def generate(self):
        # --------------------------------------------------- #
        #   载入模型与权值
        # --------------------------------------------------- #
        if self.backbone not in ['vit_b_16', 'swin_transformer_tiny', 'swin_transformer_small', 'swin_transformer_base']:
            self.model  = get_model_from_name[self.backbone](num_classes = self.num_classes, pretrained = False)
        else:
            self.model  = get_model_from_name[self.backbone](input_shape = self.input_shape, num_classes = self.num_classes, pretrained = False)
        device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.load_state_dict(torch.load(self.model_path, map_location=device))
        self.model  = self.model.eval()
        print('{} model, and classes loaded.'.format(self.model_path))

        if self.cuda:
            self.model = nn.DataParallel(self.model)
            self.model = self.model.cuda()

    # --------------------------------------------------- #
    #   检测图片,detect_image函数用于对输入图像进行分类,
    #   将会在predict.py中调用。
    # --------------------------------------------------- #
    def detect_image(self, image):
        # --------------------------------------------------------- #
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        # --------------------------------------------------------- #
        image       = cvtColor(image)
        # --------------------------------------------------- #
        #   对图片进行不失真的resize
        # --------------------------------------------------- #
        image_data  = letterbox_image(image, [self.input_shape[1], self.input_shape[0]], self.letterbox_image)
        # --------------------------------------------------------- #
        #   归一化+添加上batch_size维度+转置
        # --------------------------------------------------------- #
        image_data  = np.transpose(np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0), (0, 3, 1, 2))

        with torch.no_grad():
            # 使用torch.from_numpy函数将图片转换为pytorch中的张量,并将其传入模型进行预测。
            photo   = torch.from_numpy(image_data)
            if self.cuda:
                photo = photo.cuda()
            # --------------------------------------------------- #
            #   图片传入网络进行预测
            # --------------------------------------------------- #
            preds   = torch.softmax(self.model(photo)[0], dim=-1).cpu().numpy()
        # --------------------------------------------------- #
        #   获得所属种类
        #   使用np.argmax函数获取概率最大的类别,并根据
        # --------------------------------------------------- #
        class_name  = self.class_names[np.argmax(preds)]
        probability = np.max(preds)

        # --------------------------------------------------- #
        #   绘图并写字
        # --------------------------------------------------- #
        plt.subplot(1, 1, 1)
        # -------------------------------------------------------------------------- #
        # 将输入的image转换为numpy数组,然后使用imshow函数将数组中的像素值作为图像的像素显示出来。
        # -------------------------------------------------------------------------- #
        plt.imshow(np.array(image))
        plt.title('Class:%s Probability:%.3f' %(class_name, probability))
        plt.show()
        return class_name

四、预测

'''
predict.py有几个注意点
1、无法进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。
2、如果想要将预测结果保存成txt,可以利用open打开txt文件,使用write方法写入txt,可以参考一下txt_annotation.py文件。
'''
from PIL import Image

from classification import Classification

classfication = Classification()

while True:
    img = input('Input image filename:')
    try:
        image = Image.open(img)
    except:
        print('Open Error! Try again!')
        continue
    else:
        class_name = classfication.detect_image(image)
        print(class_name)

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值