语义分割之模型预测(inference)

小尺寸图像输入

一般的图像无需裁剪,便可直接输入模型,进行端到端的训练。它的预测过程也比较简单:以二分类为例,只需将模型输出的概率图转化为二值图。有两种方法可实现上述过程:其一,若类别数包括背景类(输出通道数为2),则沿通道维度取argmax,得到每个像素概率最大的类别索引;其二,若类别数不包括背景类(输出通道数为1),则利用sigmoid将输出值压缩至0-1之间,再用阈值法二值化,阈值一般取0.5,大于0.5为正类,否则为背景类。

# Small-image inference: load a checkpoint, run one image through the network
# and threshold the first output channel into a binary mask.

# Deserialize every tensor onto the CPU first, then move the network to `device`.
net = torch.load('./model.pth', map_location=lambda storage, loc: storage)["model"]
net = net.to(device)
net.eval()  # inference mode: dropout off, batch-norm uses its running statistics

imglist = os.listdir(input_img_folder)
img = cv2.imread(os.path.join(input_img_folder, imglist[400]))
tensor = img_to_tensor(img)
tensor = torch.unsqueeze(tensor, dim=0).float()  # add the batch dimension
with torch.no_grad():  # prediction needs no autograd graph; saves memory
    predict = net(tensor.to(device))[0, 0, :, :]  # first sample, first channel
predict = predict.detach().cpu().numpy()
# NOTE(review): thresholding at 0.5 presumes the network output is already a
# probability (e.g. a sigmoid inside the model) -- confirm against the model.
predict[predict <= 0.5] = 0   # background class
predict[predict > 0.5] = 1   # positive class

大尺寸图像输入

当图像尺寸较大时,若将整幅图像直接输入模型进行训练或预测,很容易导致显存不足(CUDA out of memory),这种情况在遥感图像中经常遇到。一般的解决方法是将大图像裁成切片,分别输入模型预测后,再将各切片的预测结果拼接起来。其步骤为:

(1)获取所有图像路径;
(2)进行for循环,将每张图像裁成切片,储存在一个临时文件夹中(完成预测后就删除),并基于此构建数据生成器;
(3)基于数据生成器,进行模型预测,将所有的概率图拼接成大的概率图,其尺寸与原图一样;
(4)将概率图转化为二值图,并根据可视化需求进行上色;
(5)最后删掉临时文件夹;对每张图像重复步骤(2)~(5)。
————————————————

## use model to predict
def predict(model):
    """Run ensemble inference over every batch of the module-level ``test_loader``.

    Args:
        model: mapping of name -> network. Each network is applied to every
            batch and the outputs are averaged before the class decision.

    Returns:
        list of numpy arrays, one per batch, each holding the index of the
        maximum along dimension 1 of the averaged network outputs.

    Note:
        Reads the module-level names ``test_loader`` and ``device``.
    """
    networks = list(model.values())
    # Switching to eval mode once is enough; nothing below changes the mode.
    for net in networks:
        net.eval()

    result = []
    # Inference only: without no_grad every forward pass would keep an
    # autograd graph alive and waste (GPU) memory.
    with torch.no_grad():
        for images in tqdm.tqdm(test_loader):
            images = images.to(device)
            summed = 0
            for net in networks:
                summed = summed + net(images)
            averaged = summed / len(networks)
            # torch.max(..., 1) returns (values, indices); keep the indices.
            preds = torch.max(averaged, 1)[1]
            result.append(preds.cpu().numpy())
    return result


def input_and_output(pic_path, model, generate_data):
    """Tile one image for prediction, or stitch the per-tile predictions together.

    The function is meant to be called twice per image:

    * ``generate_data=True``: pad the image by ``args.padding_size`` on every
      side (reflect border), cut it into overlapping tiles and write them to
      ``temp_pic/{row}_{col}.png``.
    * ``generate_data=False``: run ``predict(model)``, strip the overlap from
      every predicted tile and paste it into the full-size mask.

    Args:
        pic_path: path of the picture to predict.
        model: mapping of name -> network, forwarded to ``predict``.
        generate_data: selects the pass, see above.

    Returns:
        np.ndarray (uint8) of shape ``(row * crop_size, col * crop_size)`` where
        ``row``/``col`` are the image height/width floor-divided by
        ``args.crop_size``; pixels beyond the last full tile are not covered.
        In the ``generate_data=True`` pass the mask is all zeros.

    Raises:
        OSError: if ``cv2.imread`` cannot read ``pic_path``.
        KeyError: if a tile expected in ``temp_pic`` is missing.
    """
    image_size = args.crop_size
    pad = args.padding_size

    img = cv2.imread(pic_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"cv2.imread could not read image: {pic_path}")

    padding_img = cv2.copyMakeBorder(img, pad, pad, pad, pad, cv2.BORDER_REFLECT)
    row = img.shape[0] // image_size
    col = img.shape[1] // image_size
    mask_whole = np.zeros((row * image_size, col * image_size), dtype=np.uint8)

    if generate_data:
        for i in range(row):
            for j in range(col):
                crop = redundancy_crop(padding_img, i, j, image_size)
                cv2.imwrite(f'temp_pic/{i}_{j}.png', crop)
        return mask_whole

    result = predict(model)
    # NOTE(review): result[k] is taken to be the prediction for the k-th file
    # listed in temp_pic, i.e. the dataset behind test_loader is presumed to
    # enumerate the tiles in the same order -- confirm against the dataset.
    tile_index = {}
    for position, tile_file in enumerate(Path('temp_pic').files()):
        tile_index.setdefault(str(tile_file.name), position)

    for i in range(row):
        top = i * image_size
        for j in range(col):
            left = j * image_size
            tile = redundancy_crop2(result[tile_index[f'{i}_{j}.png']])
            mask_whole[top:top + image_size, left:left + image_size] = tile
    return mask_whole


def redundancy_crop(img, i, j, targetSize):
    """Cut tile (i, j) out of a padded image, overlap border included.

    Args:
        img: padded image array indexed as (row, column, channel).
        i: tile row index.
        j: tile column index.
        targetSize: side length of the tile without the overlap border.

    Returns:
        The slice of ``img`` starting at ``(i * targetSize, j * targetSize)``
        and spanning ``targetSize + 2 * args.padding_size`` in both directions.
    """
    span = targetSize + 2 * args.padding_size
    top = i * targetSize
    left = j * targetSize
    return img[top:top + span, left:left + span, :]


def redundancy_crop2(img):
    """Strip the overlap border from a predicted tile.

    Args:
        img: array whose axes 1 and 2 are the spatial axes of the tile.

    Returns:
        ``img`` with ``args.padding_size`` elements removed from both ends of
        axes 1 and 2; axis 0 is kept whole.
    """
    margin = args.padding_size
    height = img.shape[1]
    width = img.shape[2]
    return img[:, margin:height - margin, margin:width - margin]


def get_dataset_loaders(workers):
    """Build the test-time DataLoader.

    Args:
        workers: number of worker processes handed to the DataLoader.

    Returns:
        A DataLoader over ``urban3dDWM(..., test=True)`` with batch size 1.

    Note:
        Reads the module-level name ``path``.
        NOTE(review): ``path`` is not defined in the visible code -- confirm
        where it comes from.
    """
    dataset = urban3dDWM(os.path.join(path), './', test=True)
    return DataLoader(dataset, batch_size=1, num_workers=workers)


def get_labels():
    """Return the colour table that maps a class index to its colour.

    Returns:
        np.ndarray with dimensions (2, 3): row 0 is black, row 1 is white.
    """
    black = [0, 0, 0]
    white = [255, 255, 255]
    return np.asarray([black, white])


def decode_segmap(label_mask, n_classes):
    """Turn a class-index mask into a three-channel colour image.

    Args:
        label_mask (np.ndarray): an (M, N) array of integer class labels,
          one per spatial location.
        n_classes (int): labels ``0 .. n_classes - 1`` are recoloured using
          the rows of ``get_labels()``; it must not exceed the number of rows
          in that table. Other label values are left unchanged.

    Returns:
        np.ndarray: an (M, N, 3) array created with ``np.zeros`` and filled
        with the three recoloured planes.
    """
    palette = get_labels()
    # One working copy of the mask per colour plane.
    planes = [label_mask.copy() for _ in range(3)]
    for cls in range(0, n_classes):
        # Always select on the untouched input so recoloured values in the
        # planes can never be mistaken for labels.
        for channel, plane in enumerate(planes):
            plane[label_mask == cls] = palette[cls, channel]

    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    for channel, plane in enumerate(planes):
        rgb[:, :, channel] = plane
    return rgb


if __name__ == "__main__":
    # Script entry point: predict every "./valid/*RGB.tif" image tile by tile
    # and write the coloured mask to "./predict_<model_name>/<name>.png".
    #
    # NOTE: `args` and `test_loader` are intentionally module-level names; the
    # functions above read them as globals.
    parse = argparse.ArgumentParser()
    parse.add_argument("--n_class", type=int, default=2, help="the number of classes")
    parse.add_argument("--model_name", type=str, default='UNet', help="UNet,PSPNet,FPN")

    parse.add_argument("--n_workers", type=int, default=4, help="the number of workers")
    parse.add_argument("--crop_size", type=int, default=256,
                       help="side length in pixels of each square tile")
    parse.add_argument("--padding_size", type=int, default=32,
                       help="overlap border in pixels added around each tile")

    args = parse.parse_args()

    # Networks whose outputs are averaged; extend the list (e.g. "PSPNet",
    # "FPN") to predict with an ensemble.
    model_groups = ["UNet"]
    models = {}
    for item in model_groups:
        models[item] = torch.load(f'./results_{item}2/{item}_weights_best.pth')["model_state"]

    imgList = glob.glob("./valid/*RGB.tif")

    # NOTE(review): the output folder is named after --model_name although the
    # weights actually loaded come from model_groups -- confirm this is intended.
    save_path = f'./predict_{args.model_name}'
    os.makedirs(save_path, exist_ok=True)

    for img_path in tqdm.tqdm(imgList):
        os.makedirs('temp_pic', exist_ok=True)
        try:
            # Pass 1: write the overlapping tiles of this picture to temp_pic.
            input_and_output(img_path, models, generate_data=True)
            name = os.path.split(img_path)[-1].split(".")[0]
            # The loader is rebuilt for every picture, after its tiles exist.
            test_loader = get_dataset_loaders(args.n_workers)
            # Pass 2: predict the tiles and stitch them into one mask.
            mask_result = input_and_output(img_path, models, generate_data=False)
        finally:
            # Best-effort recursive removal, also on failure, so stale tiles
            # never leak into the next picture.
            shutil.rmtree('temp_pic', ignore_errors=True)

        decoded = decode_segmap(mask_result, args.n_class)
        cv2.imwrite(f'{save_path}/{name}.png', decoded)

为避免预测图出现网格化效应(切片边缘处预测不连续),上述代码采用了冗余预测:裁剪切片时在四周各多取padding_size个像素,预测后只保留中心crop_size×crop_size的区域用于拼接。

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

浪子私房菜

给小强一点爱心呗

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值