import os
import imgviz
import torchvision.transforms as transforms
from models.model_stages import BiSeNet #暂时注释 #导入搭建好的模型
import torch
import numpy as np
from PIL import Image
import copy
# ---- Model definition ----
n_classes = 11
# Whether to use a given shallow (boundary) feature to guide training;
# all of them are switched off here.
use_boundary_16 = False
use_boundary_8 = False
use_boundary_4 = False
use_boundary_2 = False
use_conv_last = False
# create model
backbone = 'STDCNet1446'
net = BiSeNet(backbone=backbone, n_classes=n_classes, pretrain_model='',
              use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
              use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
              use_conv_last=use_conv_last)
# Trained weights to load; set to None to skip loading a checkpoint.
ckpt = r"C:\Users\18312\PycharmProjects\RethinkBiseNet_myselfDataSet\pths\voc_model_maxmIOU75.pth"
if ckpt is not None:
    print('载入模型')  # "Loading model"
    # Load the tensors onto the CPU first; the whole network is moved to the GPU below.
    net.load_state_dict(torch.load(ckpt, map_location='cpu'))
net.cuda()
net.eval()  # inference mode

# Colour palette used by predict_to_palette to colour its 8-bit 'P'-mode
# label images; loaded once at import time.
palette = np.load('./camp2.npy').tolist()
def predict_to_palette(img, palette):
    """Segment one image and return its label map as a palettised image.

    Args:
        img: the input picture, as read by ``PIL.Image.open``. Non-RGB PIL
            images (grayscale, CMYK, RGBA, ...) are converted to RGB first,
            because the normalisation below has exactly three per-channel
            means/stds and would fail on any other channel count.
        palette: colour palette passed to ``Image.putpalette`` (presumably a
            flat ``[r, g, b, r, g, b, ...]`` list -- TODO confirm the layout
            of camp2.npy).

    Returns:
        A mode ``'P'`` ``PIL.Image`` with the input's height and width whose
        pixel values are the predicted class indices, coloured by ``palette``.
    """
    if isinstance(img, Image.Image) and img.mode != 'RGB':
        img = img.convert('RGB')
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    im = to_tensor(img).unsqueeze(0)  # add a batch dimension -> (1, C, H, W)
    height, width = im.shape[-2:]
    # Follow whatever device the model lives on instead of hard-coding .cuda().
    im = im.to(next(net.parameters()).device)
    with torch.no_grad():
        logits = net(im)[0]
        # softmax is monotonic, so argmax over the raw logits picks the same class.
        preds = torch.argmax(logits, dim=1)
    pr = preds.cpu().numpy().reshape(height, width)
    label_img = Image.fromarray(pr.astype(np.uint8))
    # putpalette on an 'L' image attaches the palette and switches it to mode 'P'.
    # (The former label_img.convert('P', ...) call discarded its result.)
    label_img.putpalette(palette)
    return label_img
def predict_to_vis(img):
    """Segment one image and return a labelled visualisation of the prediction.

    Args:
        img: the input picture, as read by ``PIL.Image.open``. Non-RGB PIL
            images are converted to RGB first: the normalisation below has
            exactly three per-channel means/stds, and the background is built
            with ``imgviz.rgb2gray`` from the RGB pixels.

    Returns:
        The array produced by ``imgviz.label2rgb``: the predicted classes
        colour-coded over a grayscale copy of ``img``, with a class-name legend
        placed at ``loc="rb"`` (right-bottom).
    """
    if isinstance(img, Image.Image) and img.mode != 'RGB':
        img = img.convert('RGB')
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    im = to_tensor(img).unsqueeze(0)  # add a batch dimension -> (1, C, H, W)
    height, width = im.shape[-2:]
    # Follow whatever device the model lives on instead of hard-coding .cuda().
    im = im.to(next(net.parameters()).device)
    with torch.no_grad():
        logits = net(im)[0]
        # softmax is monotonic, so argmax over the raw logits picks the same class.
        preds = torch.argmax(logits, dim=1)
    pr = preds.cpu().numpy().reshape(height, width)
    tmp_img = np.array(img)
    # Index i of this list names predicted class i.
    class_names = ["_background_", "car", "building", "tree", "road", "obstacle", "river", "Level lawn","Horizontal roof", "plant", "level ground"]
    viz = imgviz.label2rgb(
        label=pr,                         # predicted class map
        img=imgviz.rgb2gray(tmp_img),     # grayscale copy of the original image
        font_size=15,
        label_names=class_names,
        loc="rb",
    )
    return viz
# Batch-prediction folders:
#   img_path   -> images to segment
#   save_path1 -> palettised label maps (from predict_to_palette)
#   save_path2 -> labelled visualisations (from predict_to_vis)
img_path, save_path1, save_path2 = (
    r'E:\VOC2007\JPEGImages',
    r'E:\predict_viz\viz_1',
    r'E:\predict_viz\viz_2',
)
def predict_to_save(img_path, save_path1, save_path2):
    """Segment every image in a folder and write two PNGs per image.

    For each readable image file ``<stem>.<ext>`` directly inside ``img_path``:
      * ``save_path1/<stem>.png`` -- palettised label map (predict_to_palette);
      * ``save_path2/<stem>.png`` -- labelled visualisation (predict_to_vis).

    Missing output folders are created. Entries that are not regular files, or
    that PIL cannot open, are skipped with a message instead of aborting the run.
    """
    palette = np.load('./camp2.npy').tolist()
    os.makedirs(save_path1, exist_ok=True)
    os.makedirs(save_path2, exist_ok=True)
    for name in sorted(os.listdir(img_path)):
        path = os.path.join(img_path, name)
        if not os.path.isfile(path):
            continue  # e.g. sub-folders
        try:
            img = Image.open(path)
        except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
            print(f'Skipping {path}: {err}')
            continue
        # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b.png"); the old
        # split('.')[0] gave "a.png" and could overwrite another image's output.
        out_name = os.path.splitext(name)[0] + '.png'
        with img:  # release the file handle Image.open keeps open
            img2 = predict_to_palette(img, palette)
            img2.save(os.path.join(save_path1, out_name))
            viz = predict_to_vis(img)
        imgviz.io.imsave(os.path.join(save_path2, out_name), viz)
    print('存储完毕!')
# Script entry point: segment the whole input folder and save both outputs.
predict_to_save(img_path=img_path, save_path1=save_path1, save_path2=save_path2)
# 第一张图片为viz1的结果,第二张图片为viz2的结果。