import argparse
import numpy as np
import torch
from PIL import Image
from graphs.models.deeplab50_ClassINW import Res50_ClassINW
import cv2 as cv
# Per-channel mean in BGR order (img_transform flips RGB -> BGR before
# subtracting it).
IMG_MEAN = np.array([104.00698793, 116.66876762, 122.67891434], np.float32)

# Colour palette (RGB) used to render predictions: entry i colours label i.
# NOTE(review): these values appear to be the standard Cityscapes train-id
# palette; the class names below assume that order.
label_colours = [
    (128, 64, 128),   # road
    (244, 35, 232),   # sidewalk
    (70, 70, 70),     # building
    (102, 102, 156),  # wall
    (190, 153, 153),  # fence
    (153, 153, 153),  # pole
    (250, 170, 30),   # traffic light
    (220, 220, 0),    # traffic sign
    (107, 142, 35),   # vegetation
    (152, 251, 152),  # terrain
    (0, 130, 180),    # sky
    (220, 20, 60),    # person
    (255, 0, 0),      # rider
    (0, 0, 142),      # car
    (0, 0, 70),       # truck
    (0, 60, 100),     # bus
    (0, 80, 100),     # train
    (0, 0, 230),      # motorcycle
    (119, 11, 32),    # bicycle
    # Black, intended for the "ignored" label. It is looked up by list index
    # (19), not by -1.
    (0, 0, 0),
]
def img_transform(image, mean=None):
    """Turn an RGB image into the mean-subtracted, channels-first network input.

    The pixels are cast to float32, the channel order is flipped from RGB to
    BGR, the per-channel mean is subtracted, and the axes are moved from
    (H, W, C) to (C, H, W).

    Args:
        image: PIL image or array-like of shape (H, W, 3) in RGB order.
        mean: per-channel mean (BGR order) to subtract; defaults to IMG_MEAN.

    Returns:
        torch.Tensor of shape (3, H, W). The caller's ``image`` is never
        modified.
    """
    if mean is None:
        mean = IMG_MEAN
    pixels = np.asarray(image, dtype=np.float32)
    # Subtract out-of-place: np.asarray hands back the caller's own array when
    # it is already float32, so an in-place ``-=`` would corrupt it.
    bgr = pixels[:, :, ::-1] - mean
    # torch.from_numpy rejects negative strides, hence the contiguous copy.
    chw = np.ascontiguousarray(bgr.transpose((2, 0, 1)))
    return torch.from_numpy(chw)
def str2bool(v):
    """Parse a yes/no style string into a bool; meant for argparse ``type=``.

    Matching is case-insensitive.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognised spelling.
    """
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
def add_test_args(arg_parser):
    """Register the command-line options needed to build the test-time model.

    Args:
        arg_parser: argparse.ArgumentParser to extend in place.

    Returns:
        The same parser, so the call can be chained.
    """
    # Class indices forwarded to the model via args.selected_classes.
    # nargs='+' with type=int makes a command-line override parse into a list
    # of ints, matching the default; without it the override was a raw string.
    arg_parser.add_argument('--selected_classes', default=[0, 10, 2, 1, 8],
                            nargs='+', type=int,
                            help="class indices passed to the model as "
                                 "args.selected_classes")
    arg_parser.add_argument('--imagenet_pretrained', type=str2bool, default=True,
                            help="whether apply imagenet pretrained weights")
    arg_parser.add_argument('--num_classes', default=19, type=int,
                            help='num class of mask')
    return arg_parser
def load_test_image(image_path, size=(1024, 512)):
    """Read an image from disk and prepare it as network input.

    Args:
        image_path: path of the image file to read.
        size: (width, height) the RGB image is resized to with bicubic
            interpolation; defaults to the previously hard-coded 1024x512.

    Returns:
        The channels-first tensor produced by img_transform.
    """
    # The context manager closes the file that Image.open keeps open for lazy
    # decoding; convert() has already produced an independent image by then.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image = image.resize(size, Image.BICUBIC)
    return img_transform(image)
def show_result(result, colours=None):
    """Render the first label map of a batch as a BGR colour image.

    Args:
        result: integer label maps. Only result[0] is rendered, and the
            output size is taken from result.shape[1] and result.shape[2]
            (presumably the (N, H, W) argmax of the network output).
        colours: sequence of RGB triples where entry i colours label i.
            Defaults to the module-level label_colours.

    Returns:
        C-contiguous np.uint8 array of shape (H, W, 3) in BGR channel order,
        as cv.imwrite expects. Pixels whose label has no palette entry stay
        black.
    """
    if colours is None:
        colours = label_colours
    seg = result[0]
    color_seg = np.zeros((result.shape[1], result.shape[2], 3), dtype=np.uint8)
    for label, color in enumerate(colours):
        color_seg[seg == label, :] = color
    # Palette is RGB while OpenCV writes BGR: reverse the channel axis and
    # return a contiguous copy rather than a negative-stride view.
    return np.ascontiguousarray(color_seg[..., ::-1])
if __name__ == '__main__':
    # Default paths of the trained checkpoint and of the image to segment.
    # Both (and the output file) can be overridden from the command line.
    checkpoint_file = r'C:\Users\63108\Desktop\SAN-SAW-main\tools\log\gta5_pretrain_2\best.pt'
    image_path = r'C:\Users\63108\Desktop\SAN-SAW-main\DATASET\datasets_original\Cityscapes\leftImg8bit\val\frankfurt\frankfurt_000000_000294_leftImg8bit.png'

    # Command-line arguments needed to build the model plus the I/O paths.
    arg_parser = argparse.ArgumentParser()
    arg_parser = add_test_args(arg_parser)
    arg_parser.add_argument('--checkpoint', default=checkpoint_file,
                            help='path of the trained weights file')
    arg_parser.add_argument('--image', default=image_path,
                            help='path of the image to segment')
    arg_parser.add_argument('--output', default='1.png',
                            help='where to write the colourised prediction')
    args = arg_parser.parse_args()

    # Build the model.
    model = Res50_ClassINW(args, num_classes=args.num_classes,
                           pretrained=args.imagenet_pretrained)

    # Load the trained weights. map_location="cpu" lets a checkpoint saved on
    # a GPU load on a CPU-only machine; the model is moved to the device after.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    saved_state_dict = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(saved_state_dict)
    model.to(device)
    model.eval()

    test_image = load_test_image(args.image)
    test_image = torch.unsqueeze(test_image, dim=0).to(device)  # add batch dim

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        result = model(test_image)
    # NOTE(review): assumes the model returns a sequence whose first item is
    # the (N, C, H, W) class-score tensor — confirm against Res50_ClassINW.
    pred = result[0].cpu().numpy()
    argpred = np.argmax(pred, axis=1)
    color_map = show_result(argpred)
    if not cv.imwrite(args.output, color_map):
        raise OSError('could not write prediction to %s' % args.output)
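# 输入你的color_map — fill in the colour map (label_colours) for your own dataset.
# 加载你的权重文件和你的测试图片 — point checkpoint_file and image_path at your own weights file and test image.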