#!/usr/bin/python
# -*- encoding: utf-8 -*-
from logger import setup_logger
from model import BiSeNet
import torch
import os
import os.path as osp
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import cv2
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
# Colors for all 20 parts
part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
[255, 0, 85], [255, 0, 170],
[0, 255, 0], [85, 255, 0], [170, 255, 0],
[0, 255, 85], [0, 255, 170],
[0, 0, 255], [85, 0, 255], [170, 0, 255],
[0, 85, 255], [0, 170, 255],
[255, 255, 0], [255, 255, 85], [255, 255, 170],
[255, 0, 255], [255, 85, 255], [255, 170, 255],
[0, 255, 255], [85, 255, 255], [170, 255, 255]]
im = np.array(im)
vis_im = im.copy().astype(np.uint8)
vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
num_of_class = np.max(vis_parsing_anno)
for pi in range(1, num_of_class + 1):
index = np.where(vis_parsing_anno == pi)
vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
# print(vis_parsing_anno_color.shape, vis_im.shape)
vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
# Save result or not
if save_im:
cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno)
cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
# return vis_im
def evaluate(respth='./res/test_res', dspth='./test_data', cp='79999_iter.pth'):
if not os.path.exists(respth):
os.makedirs(respth)
n_classes = 19
net = BiSeNet(n_classes=n_classes)
net.cuda()
save_pth = osp.join('res/cp', cp) #weight
net.load_state_dict(torch.load(save_pth))
net.eval()
to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
with torch.no_grad():
for image_path in os.listdir(dspth):
img = Image.open(osp.join(dspth, image_path))
image = img.resize((512, 512), Image.BILINEAR)
img = to_tensor(image)
img = torch.unsqueeze(img, 0)
img = img.cuda()
out = net(img)[0]
parsing = out.squeeze(0).cpu().numpy().argmax(0)
# print(parsing)
print(np.unique(parsing))
vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
if __name__ == "__main__":
evaluate() #dspth='./res', cp='79999_iter.pth'
这段代码是用于图像语义分割的任务。代码主要进行了模型的加载、图像的预处理、模型的推理以及结果的可视化等步骤。下面我将会逐行进行解释:
-
from logger import setup_logger
:从logger
模块导入setup_logger
函数,用于设置日志。 -
from model import BiSeNet
:从model
模块导入BiSeNet
模型,BiSeNet是一个用于语义分割的深度学习模型。 -
import torch
:导入PyTorch库,它是一个开源的深度学习框架。 -
import os
和import os.path as osp
:导入Python的os
模块和os.path
模块,用于处理文件和目录。 -
import numpy as np
:导入NumPy库,一个用于处理数组数据的库。 -
from PIL import Image
:从PIL
模块导入Image
,用于处理图像。 -
import torchvision.transforms as transforms
:导入torchvision.transforms
模块,这个模块提供了许多图像预处理的方法。 -
import cv2
:导入OpenCV库,一个包含了各种图像处理功能的库。
9-30. 定义了vis_parsing_maps
函数,这个函数的目的是可视化分割结果。它通过将分割结果(不同的类别)映射到不同颜色上来进行可视化。
31-65. 定义了evaluate
函数,这个函数用于载入模型并进行推理。它首先创建存储结果的文件夹,然后加载BiSeNet模型,然后对输入文件夹中的每一张图片,进行预处理并通过模型进行推理。之后,使用vis_parsing_maps
函数对推理结果进行可视化,并保存。
66-67. 如果这个脚本作为主程序运行,那么就调用evaluate
函数,传入图片文件夹路径和模型文件的路径。在这个例子中,图片文件夹路径是'./res',模型文件路径是'79999_iter.pth'。
这里是每一行代码的详细解释:
-
from logger import setup_logger
:从logger
模块导入setup_logger
函数,但是在这段代码中没有使用到,这个函数通常用于设置日志记录器。 -
from model import BiSeNet
:从model
模块导入BiSeNet
模型,BiSeNet是一个用于语义分割的深度学习模型。 -
import torch
:导入PyTorch库,一个基于Python的科学计算包,用于深度学习研究和开发。 -
import os
:导入Python的os
模块,一个常用的模块,提供了非常丰富的方法用来处理文件和目录。 -
import os.path as osp
:导入Python的os.path
模块,并给它一个简写的别名osp
,这个模块主要用于获取、检查、修改文件或目录路径。 -
import numpy as np
:导入NumPy库,并给它一个别名np
,NumPy是Python的一个数值计算扩展库,提供了矩阵,线性代数,傅立叶变换等等。 -
from PIL import Image
:从Python的图像处理库PIL(Python Image Library)中导入Image
模块,用于打开、处理图像。 -
import torchvision.transforms as transforms
:导入PyTorch中的torchvision.transforms
模块,并给它一个别名transforms
,这个模块提供了许多图像预处理的方法。 -
import cv2
:导入OpenCV库,一个专门用于实时计算机视觉的库。
10-16. 定义函数vis_parsing_maps
,它有五个参数:输入图像im
,分割注解parsing_anno
,步长stride
,是否保存图像save_im
,以及保存路径save_path
。
17-29. part_colors
是一个二维数组,包含了24种RGB颜色,用于对分割的不同部分进行着色。
31-34. 把输入图像转为NumPy数组,并且复制一份原始图像vis_im
和分割注解vis_parsing_anno
,它们的数据类型都转为uint8
。
-
使用OpenCV的
resize
函数,将分割注解的尺寸根据步长放大。 -
初始化一张和放大后的
vis_parsing_anno
大小相同,值全部为255的3通道图像vis_parsing_anno_color
,用于存储着色后的分割图。 -
获取
vis_parsing_anno
中的最大值,即类别的数量。
38-40. 遍历每一类,将分割图中每一类的部分颜色化。
-
将着色后的
vis_parsing_anno_color
转为uint8
类型。 -
将原始图像和着色后的分割图进行加权混合,得到最终的可视化图像。
43-46. 判断如果需要保存图像,就将着色的分割图和混合后的图像保存为.png和.jpg文件。
48-65. 定义evaluate
函数,接受三个参数,结果保存路径respth
,测试数据路径dspth
,和模型权重文件路径cp
。
49-51. 检查结果保存路径是否存在,如果不存在,就创建这个目录。
- 初始化一个BiSeNet模型,并将模型移动到GPU上。
53-54. 拼接得到模型文件的完整路径,然后加载模型。
- 将模型设置为评估模式,这样在进行前向传播时,不会计算梯度,也不会进行dropout。
57-60. 定义图像预处理操作,包括转为张量,和归一化处理。
61-64. 在不计算梯度的情况下,遍历数据文件夹中的每一张图像,将图像预处理后输入到模型中,得到输出结果,然后将输出结果可视化,并保存。
66-67. 如果这个脚本作为主程序运行,那么就调用evaluate
函数,传入数据文件夹路径和模型文件的路径。在这个例子中,数据文件夹路径是'./res',模型文件路径是'79999_iter.pth'。
# Colors for all 20 parts
# Colors for all 20 parts
part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
[255, 0, 85], [255, 0, 170],
[0, 255, 0], [85, 255, 0], [170, 255, 0],
[0, 255, 85], [0, 255, 170],
[0, 0, 255], [85, 0, 255], [170, 0, 255],
[0, 85, 255], [0, 170, 255],
[255, 255, 0], [255, 255, 85], [255, 255, 170],
[255, 0, 255], [255, 85, 255], [255, 170, 255],
[0, 255, 255], [85, 255, 255], [170, 255, 255]]
这段代码定义了一个名为part_colors
的二维列表,其中包含24个子列表,每个子列表包含3个元素。每个子列表代表一个RGB颜色值,用来在可视化中区分不同的语义类别。
RGB色彩模式是工业界的一种颜色标准,是通过对红(Red)、绿(Green)、蓝(Blue)三个颜色通道的变化以及它们相互之间的叠加来得到各式各样的颜色的,RGB即是代表红、绿、蓝三个通道的颜色,这个标准几乎包括了人类视力所能感知的所有颜色。
例如,[255, 0, 0]
代表纯红色,[0, 255, 0]
代表纯绿色,[0, 0, 255]
代表纯蓝色。[255, 255, 0]
则代表红色和绿色的叠加,即黄色,等等。对于分割任务来说,每一个类别都被映射到了这24种颜色中的一种,以便于在可视化结果图像中区分不同的类别。
在这个脚本的vis_parsing_maps
函数中,将分割结果中的每一类别映射到part_colors
中的一种颜色上,用于生成颜色化的分割结果图像。
import torchvision.transforms as transforms 这句代码是什么意思
这行代码从torchvision
库中导入了transforms
模块,并且使用了别名transforms
。
torchvision
是PyTorch的一个附属库,专门用于处理图片和视频,它包含了许多处理图像的预处理操作,比如裁剪,旋转,调整大小,调整亮度,对比度,饱和度,色调,噪声,等等。
transforms
模块中提供了大量的类和函数,用于进行图像预处理和数据增强。比如:
transforms.Resize
:用于调整图像的尺寸。transforms.CenterCrop
:用于裁剪图像的中心区域。transforms.ToTensor
:用于将PIL图像或者NumPy数组转为PyTorch张量。transforms.Normalize
:用于对图像进行归一化。
在这个脚本中,transforms
模块被用于创建图像预处理的pipeline:
pythonCopy code
to_tensor = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ])
这段代码首先将图像转为张量,然后对张量进行归一化。对图像的RGB通道进行归一化能够加速模型的收敛,减少模型训练的时间,也有可能提高模型的性能。归一化的参数是根据ImageNet数据集的统计结果得到的。
n_classes = 19
net = BiSeNet(n_classes=n_classes) 这句代码是什么意思
这两行代码定义了一个新的神经网络实例,其网络结构是BiSeNet
,并且指定了神经网络的分类数量(也就是输出的维度)为19。
n_classes = 19
这行代码定义了一个名为n_classes
的变量,其值为19。这个变量表示的是网络用于分类的类别数量。在语义分割任务中,这通常指的是需要区分的语义类别的数量,例如,如果是在城市街景数据集(Cityscapes)中进行分割,这个值可能是包含各种不同类型的对象,如建筑、汽车、行人等。
net = BiSeNet(n_classes=n_classes)
这行代码创建了一个新的BiSeNet
的实例,并将n_classes
作为参数传递给BiSeNet
的构造函数,来指定网络输出的类别数量。BiSeNet
是一种专门用于进行语义分割的深度神经网络,它通过构建明确的分割分支,有效地合并了高分辨率的细节信息和低分辨率的语义信息,以实现高效的像素级分割。
所以,简而言之,这两行代码定义了一个具有19个输出类别的BiSeNet
网络实例。
vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) 这句代码是什么意思
这行代码使用了NumPy数组的astype
方法将数组vis_parsing_anno_color
的元素类型转换为uint8
类型。
NumPy的astype
函数可以创建原数组的一个新的副本,并为副本中的元素指定新的数据类型。在这个例子中,astype(np.uint8)
表示创建一个新的数组,其中的元素类型都为np.uint8
。
uint8
是一种数据类型,表示无符号(即非负)的8位整数,取值范围从0到255。这是图像处理中常用的一种数据类型,因为像素的每个通道的亮度值一般都是在0到255之间。在将vis_parsing_anno_color
数组用于OpenCV的cv2.addWeighted
函数之前,需要先将其元素的类型转换为np.uint8
,否则可能会因为数据类型不匹配而引发错误。
所以,这行代码的作用是将vis_parsing_anno_color
数组中的元素的类型转换为np.uint8
,以便于后续的图像处理操作。
for pi in range(1, num_of_class + 1):
index = np.where(vis_parsing_anno == pi)
vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
这个for
循环在代码中的作用是遍历vis_parsing_anno
数组中的每一种类别,并根据类别值对应地设置vis_parsing_anno_color
数组中的像素的颜色。
让我们逐行分析这个循环:
-
for pi in range(1, num_of_class + 1):
:这行代码中的range(1, num_of_class + 1)
产生一个范围从1到num_of_class
的整数序列,每次循环pi
会取这个序列中的一个值。这里注意range
函数的上限是不包括的,所以num_of_class + 1
确保pi
的值会到num_of_class
。这个循环实际上是在遍历所有的类别。 -
index = np.where(vis_parsing_anno == pi)
:这行代码找到vis_parsing_anno
数组中值为pi
的元素的位置,并将这些位置存储在index
变量中。np.where
函数返回一个元组,其中包含若干个数组,每个数组的元素是满足条件的元素在对应维度上的索引。 -
vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
:这行代码将vis_parsing_anno_color
数组中对应的像素的颜色设置为part_colors[pi]
。这里,index[0]
和index[1]
是vis_parsing_anno
数组中值为pi
的元素在两个维度上的索引,:
表示在第三个维度上(也就是颜色通道上)选取所有元素。 -
所以,这个
for
循环的作用是将每个语义类别映射为一种颜色,然后将分割结果vis_parsing_anno
转换为彩色图像vis_parsing_anno_color
,以便于进行可视化。