Small-size image input
An ordinary-sized image can be fed into the model without cropping and trained end to end. Prediction is equally simple: taking binary segmentation as an example, the probability map output by the model is converted into a binary map. There are two ways to do this. First, if the set of classes includes a background class, apply argmax along the channel dimension and take, at each pixel, the class with the largest score. Second, if the classes do not include background, squash the output to the range 0-1 with a sigmoid and apply a threshold, usually 0.5: values above 0.5 belong to the positive class, values below it to the background. The snippet below shows the threshold approach; an argmax variant follows it.
import os
import cv2
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load the trained model and switch it to inference mode
net = torch.load('./model.pth', map_location=lambda storage, loc: storage)["model"]
net = net.to(device)
net.eval()

imglist = os.listdir(input_img_folder)  # input_img_folder: directory holding the test images
img = cv2.imread(os.path.join(input_img_folder, imglist[400]))
tensor = img_to_tensor(img)                      # helper that turns the HWC uint8 image into a CHW float tensor
tensor = torch.unsqueeze(tensor, dim=0).float()  # add the batch dimension
with torch.no_grad():                            # no gradients needed at inference time
    predict = net(tensor.to(device))[0, 0, :, :] # probability map of the positive class
predict = predict.cpu().numpy()
predict[predict <= 0.5] = 0  # background class
predict[predict > 0.5] = 1   # positive class
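For completeness, here is a minimal sketch of the first method mentioned above (argmax over an output that includes a background channel). It reuses the net, tensor and device names from the snippet above and assumes the network emits one channel per class with the positive class at index 1; both assumptions are illustrative, not taken from the original code.

# Method one (sketch): the network outputs C channels, background included;
# take the per-pixel class with the highest score.
with torch.no_grad():
    logits = net(tensor.to(device))          # assumed shape: (1, C, H, W)
pred_class = torch.argmax(logits, dim=1)      # shape: (1, H, W), integer class labels
binary_map = (pred_class[0] == 1).cpu().numpy().astype('uint8')  # assumes class 1 = positive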
Large-size image input
When the image is large, feeding the whole thing into the model easily triggers a CUDA out-of-memory error; this happens frequently with remote-sensing imagery. The usual remedy is to crop the large image into tiles and stitch the per-tile predictions back together afterwards. The steps are as follows (a small sketch of the tile arithmetic follows the list):
(1) collect the paths of all images;
(2) loop over the images: crop each one into tiles, save them to a temporary folder (deleted once prediction finishes), and build a data loader on top of that folder;
(3) run the model over the data loader and stitch all tile probability maps into one large probability map the size of the original image;
(4) convert the probability map into a binary map and color it as required for visualization;
(5) finally, delete the temporary folder and repeat (2)-(4) for the next image.
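A quick sketch of the tile arithmetic the script below relies on. crop_size and padding_size mirror its command-line defaults; the image size is a made-up example. Because integer division is used, the stitched mask measures row*crop_size by col*crop_size pixels, so any remainder at the right and bottom edges of the original image is silently dropped.

# Tile-grid arithmetic used by the script below (illustrative numbers only).
crop_size, padding_size = 256, 32
H, W = 2000, 2000                                  # example original image size

row, col = H // crop_size, W // crop_size          # 7 x 7 tiles
tile_size = crop_size + 2 * padding_size           # each saved tile is 320 x 320
out_h, out_w = row * crop_size, col * crop_size    # stitched mask is 1792 x 1792
print(row, col, tile_size, out_h, out_w)           # -> 7 7 320 1792 1792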
import argparse
import glob
import os
import shutil
import cv2
import numpy as np
import torch
import tqdm
from path import Path                  # path.py package (pip install path), not pathlib
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## use model to predict
def predict(model):
    """Average the outputs of every model in the dict over the module-level test_loader (simple ensemble)."""
    result = []
    for images in tqdm.tqdm(test_loader):
        images = images.to(device)
        temp = 0
        with torch.no_grad():
            for keys in model.keys():
                model[keys].eval()
                outputs = model[keys](images)
                temp += outputs
        preds = temp / len(model)       # average the raw outputs across models
        preds = torch.max(preds, 1)[1]  # per-pixel argmax over the class dimension
        result.append(preds.cpu().numpy())
    return result
def input_and_output(pic_path, model, generate_data):
    """
    args:
        pic_path : the picture you want to predict
        model : dict of models used for prediction
    note:
        step one (generate_data=True)  : cut the picture into overlapping tiles
        step two (generate_data=False) : predict on the generated tiles and stitch the result
    """
    image_size = args.crop_size
    img = cv2.imread(pic_path)
    b = args.padding_size
    # reflect-pad the image so that border tiles also get full context
    image = cv2.copyMakeBorder(img, b, b, b, b, cv2.BORDER_REFLECT)
    h, w = image.shape[0], image.shape[1]
    # number of tiles per axis; any remainder at the right/bottom edge is dropped
    row = img.shape[0] // image_size
    col = img.shape[1] // image_size
    padding_img = np.zeros((h, w, 3), dtype=np.uint8)
    padding_img[0:h, 0:w, :] = image[:, :, :]
    mask_whole = np.zeros((row*image_size, col*image_size), dtype=np.uint8)
    if not generate_data:
        result = predict(model)
        # file names in the same order the dataset reads them, so that
        # result[k] corresponds to map_list[k]
        map_list = [str(i.name) for i in Path('temp_pic').files()]
    for i in range(row):
        for j in range(col):
            if generate_data:
                # tile of size image_size + 2*padding_size, saved for the data loader
                crop = redundancy_crop(padding_img, i, j, image_size)
                cv2.imwrite(f'temp_pic/{i}_{j}.png', crop)
            else:
                temp = result[map_list.index(f'{i}_{j}.png')]
                # strip the padded border, then drop the batch dimension
                temp = redundancy_crop2(temp)[0]
                mask_whole[i*image_size:i*image_size+image_size, j*image_size:j*image_size+image_size] = temp
    return mask_whole
def redundancy_crop(img, i, j, targetSize):
    # cut tile (i, j) plus a padding_size border on every side
    temp_img = img[i*targetSize:i*targetSize+targetSize+2*args.padding_size,
                   j*targetSize:j*targetSize+targetSize+2*args.padding_size, :]
    return temp_img

def redundancy_crop2(img):
    # remove the padded border from a predicted tile of shape (batch, h, w)
    h = img.shape[1]
    w = img.shape[2]
    temp_img = img[:, args.padding_size:h-args.padding_size, args.padding_size:w-args.padding_size]
    return temp_img
def get_dataset_loaders(workers):
    batch_size = 1
    # urban3dDWM is the project's dataset class; `path` should point at the
    # temporary tile folder written by input_and_output (here: temp_pic)
    test_dataset = urban3dDWM(
        os.path.join(path), './', test=True
    )
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=workers)
    return test_loader
def get_labels():
    """Return the color map that associates each class label with an RGB color.
    Returns:
        np.ndarray with dimensions (2, 3): black for background, white for the positive class
    """
    return np.asarray(
        [
            [0, 0, 0],
            [255, 255, 255]
        ]
    )
def decode_segmap(label_mask, n_classes):
    """Decode segmentation class labels into a color image
    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
            the class label at each spatial location.
        n_classes (int): number of classes in the color map.
    Returns:
        np.ndarray: the decoded color image, shape (M, N, 3), dtype uint8.
    """
    label_colours = get_labels()
    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    # uint8 so that cv2.imwrite saves the image correctly
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3), dtype=np.uint8)
    rgb[:, :, 0] = r
    rgb[:, :, 1] = g
    rgb[:, :, 2] = b
    return rgb
if __name__ == "__main__":
    parse = argparse.ArgumentParser()
    parse.add_argument("--n_class", type=int, default=2, help="the number of classes")
    parse.add_argument("--model_name", type=str, default='UNet', help="UNet, PSPNet or FPN")
    parse.add_argument("--n_workers", type=int, default=4, help="the number of workers")
    parse.add_argument("--crop_size", type=int, default=256, help="tile size used for cropping")
    parse.add_argument("--padding_size", type=int, default=32, help="overlap padding around each tile")
    args = parse.parse_args()
    # add more entries, e.g. ["UNet", "PSPNet", "FPN"], to ensemble several models
    model_groups = ["UNet"]
    models = {}
    for index, item in enumerate(model_groups):
        # the checkpoint is expected to store the full model object under "model_state"
        models[item] = torch.load(f'./results_{item}2/{item}_weights_best.pth')["model_state"]
    imgList = glob.glob("./valid/*RGB.tif")
    num = len(imgList)
    save_path = f'./predict_{args.model_name}'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    for i in tqdm.tqdm(range(num)):
        if not os.path.exists('temp_pic'):
            os.makedirs('temp_pic')
        # step one: cut the picture into tiles stored under temp_pic
        input_and_output(imgList[i], models, generate_data=True)
        name = os.path.split(imgList[i])[-1].split(".")[0]
        test_loader = get_dataset_loaders(args.n_workers)
        # step two: predict on the tiles and stitch them back into a full mask
        mask_result = input_and_output(imgList[i], models, generate_data=False)
        # remove the temporary tile folder
        try:
            shutil.rmtree('temp_pic')
        except OSError:
            pass
        decoded = decode_segmap(mask_result, args.n_class)
        cv2.imwrite(f'{save_path}/{name}.png', decoded)
To keep the stitched prediction from showing grid artifacts along tile borders, the code above uses redundant (overlapping) prediction: each tile is cut with an extra padding_size border, the model predicts on the enlarged tile, and only the central region is written back into the output mask.
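A minimal sketch of that overlap, using the same crop_size/padding_size defaults as the script; the helper names and printed numbers are illustrative only. Each tile is read with an extra padding_size border, and after prediction only the central crop_size-by-crop_size window is kept, so the kept windows of neighbouring tiles abut exactly and every seam falls in a region the model saw with full context.

# Illustration of the redundant (overlapping) tiles (values mirror the defaults above).
crop_size, padding_size = 256, 32

def tile_extent(i):
    # rows read from the reflect-padded image for tile index i
    start = i * crop_size
    return start, start + crop_size + 2 * padding_size

def kept_extent(i):
    # rows of the original image that tile i actually contributes to the output
    return i * crop_size, (i + 1) * crop_size

print(tile_extent(0), kept_extent(0))   # (0, 320)   (0, 256)
print(tile_extent(1), kept_extent(1))   # (256, 576) (256, 512) -> kept regions abut with no gap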