1、背景
在进行图像分割预测的时候,我们输入网络的图片往往是正方形的(512x512,224x224甚至更小),那么预测出图片的大小也和输入网络的图片大小一样,是正方形的,怎么让预测出来的图片和原始图片长宽比例相同呢?知道我关注了b站up主Bubbliiing,才把这个问题解决,这里特别感谢。他的代码与思路如下:
2、添加灰度条
由于在图片进入到网络之前,往往需要裁剪或者resize,那么对于原始图片是长方形的,如果裁剪,则会丢失信息,如果resize则会失真,因此添加灰度条是保证图片既不会失真也不会丢失信息的一个方法。
def letterbox_image(self ,image, size):
image = image.convert("RGB")
iw, ih = image.size
w, h = size
scale = min(w/iw, h/ih)
nw = int(iw*scale)
nh = int(ih*scale)
image = image.resize((nw,nh), Image.BICUBIC)
new_image = Image.new('RGB', size, (128,128,128))
new_image.paste(image, ((w-nw)//2, (h-nh)//2))
return new_image,nw,nh
3、思路与代码
第一步:进行原始图片的备份,并计算图片的高和宽。
第二步:对输入进来的图片添加灰度条,其本质是不失真的resize。
第三步:对输入图片进行归一化,所有像素点全部除以255,同时加上batch_size维度,并进行transpose。
第四步:图片传入网络进行预测,对预测的结果进行permute操作将通道数转到最后一维。
第五步:对预测结果取一个softmax操作,取出每个像素点对应最大概率值的种类。
第六步:对最终的预测结果进行截取,将灰条部分截取掉。
第七步:判断每个像素点的种类,对每个像素点分配一个特定的颜色。
第八步:将得到的分割图像转化为image并进行resize,与原始图片进行复合即可。
下面看代码:
#---------------------------------------------------#
# 检测图片
#---------------------------------------------------#
def detect_image(self, image):
# 进行原始图片的备份
old_img = copy.deepcopy(image)
orininal_h = np.array(image).shape[0]
orininal_w = np.array(image).shape[1]
image, nw, nh = self.letterbox_image(image,(self.model_image_size[1],self.model_image_size[0]))
images = [np.array(image)/255]
images = np.transpose(images,(0,3,1,2))
with torch.no_grad():
images = Variable(torch.from_numpy(images).type(torch.FloatTensor))
if self.cuda:
images =images.cuda()
pr = self.net(images)[0]
pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy().argmax(axis=-1)
pr = pr[int((self.model_image_size[0]-nh)//2):int((self.model_image_size[0]-nh)//2+nh), int((self.model_image_size[1]-nw)//2):int((self.model_image_size[1]-nw)//2+nw)]
seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3))
for c in range(self.num_classes):
seg_img[:,:,0] += ((pr[:,: ] == c )*( self.colors[c][0] )).astype('uint8')
seg_img[:,:,1] += ((pr[:,: ] == c )*( self.colors[c][1] )).astype('uint8')
seg_img[:,:,2] += ((pr[:,: ] == c )*( self.colors[c][2] )).astype('uint8')
image = Image.fromarray(np.uint8(seg_img)).resize((orininal_w,orininal_h))
if self.blend:
image = Image.blend(old_img,image,0.7)
return image
代码选自这里。