from __future__ import print_function
import torch
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import time
import cv2
# from models.Eca_ASP_v6 import eca_ASP_v6
# from models.res import res
from models.res_eca import res_eca
# from models.Eca_ASP_1x import eca_ASP_1x
# from models.res_eca_v3head import res_eca_v3head
# from models.res_eca_v3plushead import res_eca_v3plushead
# from models.res_wudecorder import res_wudecorder
# from models.res_eca_v3head import res_eca_v3head
#---------------------------------------------------------------#
# Run configuration: module-level settings read by yuce() and pinjie().
#---------------------------------------------------------------#
# Folder of input tiles, and folder that receives one predicted mask per tile.
val_path = './test/384/val/'
predict_path = './test/pred/'
valList = os.listdir(val_path)
num = len(valList)  # how many tiles yuce() processes

# Threshold used to binarise the model output (author's note: "0.4 highest").
yuzhi = 0.5

# Stitching settings: stitched outputs go to pathout; the file names (and
# count) found in gt_path decide how many full images are assembled.
# pathlabel = './test/pred/'
pathout = './test/pred_he_acc_0.5/'
gt_path = './test/target/'
use_overlap = False
# NOTE: this flag has no effect -- the later ``def pinjie()`` rebinds the
# name ``pinjie`` to that function, so anything testing it sees a truthy
# function object instead of this bool.
pinjie = True

# Tile height and width in pixels.
h = w = 384
def binarize(pred, threshold):
    """Return a mask of ``pred`` (same dtype) that is 1 where ``pred >= threshold``, else 0.

    A single comparison replaces the former two in-place passes
    (``x[x >= t] = 1`` then ``x[x < t] = 0``). Those passes reset every
    positive pixel back to 0 whenever ``threshold > 1``, and they left
    NaNs untouched. Here NaNs map to 0.
    """
    return (pred >= threshold).to(pred.dtype)


def yuce():
    """Predict a binary mask for every image in ``val_path``.

    Loads the ``res_eca`` checkpoint and moves it to CUDA. It then runs each
    file in ``valList[:num]`` through the network. The first model output is
    thresholded at ``yuzhi`` with ``binarize`` and saved via ``ToPILImage`` to
    ``predict_path`` under the input's own file name. The function also
    creates ``predict_path`` and ``pathout`` if missing, and prints the
    elapsed time. A CUDA device is required.
    """
    # Alternative configurations that were tried (kept for reference):
    # model = eca_ASP_v6(layers=50, classes=1, pretrained=False, use_aux=True)
    # model = res(layers=50, classes=1, pretrained=False, use_aux=False)
    # model_path = './checkpoint/res/model/netG_best_acc.pth
    # model_path = './checkpoint/Eca_ASPx2_aux_wubn/model/netG_best_acc.pth'
    # model = eca_ASP_1x(layers=50, classes=1, pretrained=False, use_aux=False)
    # model_path = './checkpoint/eca_ASP_1x/model/netG_best_acc.pth'
    model = res_eca(layers=101, classes=1, pretrained=False, use_aux=True)
    model_path = './checkpoint/101/netG_best_acc.pth'
    model.load_state_dict(torch.load(model_path, map_location='cuda'))
    # Partial loading, for checkpoints whose keys only partly match:
    # checkpoint_finetune = torch.load(model_path, map_location='cuda')
    # model_dict = model.state_dict()
    # pretrained_dict = {k: v for k, v in checkpoint_finetune.items() if k in model_dict}
    # model_dict.update(pretrained_dict)
    # model.load_state_dict(model_dict)
    model = model.cuda()
    model.eval()

    # exist_ok=True replaces ``try/except OSError: pass``, which also hid real
    # failures such as missing permissions.
    os.makedirs(predict_path, exist_ok=True)
    os.makedirs(pathout, exist_ok=True)

    #---------------------------------------------------------------#
    transform1 = transforms.Compose([
        transforms.ToTensor(),
        # Per-channel standardisation with fixed dataset statistics
        # ((x - mean) / std). This is NOT a rescale to [-1, 1].
        transforms.Normalize(mean=[0.31701732, 0.32337377, 0.28925751],
                             std=[0.17323045, 0.16700189, 0.16922423]),
    ])
    to_pil = transforms.ToPILImage()

    start = time.time()
    # Inference only: no_grad avoids building (and holding memory for) the
    # autograd graph of every forward pass.
    with torch.no_grad():
        for name in tqdm(valList[:num]):
            with Image.open(val_path + name) as src:  # close the file handle
                test_image = src.convert('RGB')
            img = transform1(test_image).unsqueeze(0).cuda()  # add batch dim
            # The model returns two outputs; only the first is used (the second
            # is presumably the auxiliary head enabled by use_aux=True).
            pred, _aux = model(img)
            mask = binarize(pred.squeeze(0), yuzhi).cpu()  # drop batch dim
            to_pil(mask).save(predict_path + name)
    end = time.time()
    print('Program processed ', end - start, 's, ', (end - start) / 60, 'min ')
def crop_overlap_tile(tile, row, col, num_pin, th, tw):
    """Return the part of an overlapping tile that belongs in the mosaic.

    On every side that touches a neighbouring tile, a quarter of the tile
    (``th // 4`` rows or ``tw // 4`` columns) is dropped. Sides on the mosaic
    border are kept in full, so interior tiles keep their central half.
    (The old code cropped the bottom-left tile with columns ``tw//4:tw``
    instead of ``0:tw - tw//4``, shifting its content by a quarter tile.)

    Args:
        tile: 2-D (grayscale) tile array.
        row, col: position of the tile in the ``num_pin`` x ``num_pin`` grid.
        num_pin: number of tiles per row/column.
        th, tw: nominal tile height/width used to compute the margins.
    """
    top = 0 if row == 0 else th // 4
    bottom = th if row == num_pin - 1 else th - th // 4
    left = 0 if col == 0 else tw // 4
    right = tw if col == num_pin - 1 else tw - tw // 4
    return tile[top:bottom, left:right]


def pinjie(mosaic_size=1536, border=18):
    """Stitch the per-tile prediction masks back into full-size images.

    For the ``l``-th name in ``os.listdir(gt_path)``, the tiles
    ``predict_path + str(idx) + ".tif"`` with ``idx = l*n*n + j*n + k``
    (row ``j``, column ``k`` of an ``n`` x ``n`` grid) are read as grayscale.
    They are stitched row by row and cropped to ``[border, mosaic_size -
    border)`` on both axes. The result is written to ``pathout`` under that
    name.

    With ``use_overlap``, the grid has ``mosaic_size // h * 2 - 1`` tiles per
    side, which is consistent with tiles cut at a stride of half a tile. Each
    tile is then trimmed by ``crop_overlap_tile``. Otherwise the tiles are
    abutted unchanged, with ``mosaic_size // h`` per side.

    Args:
        mosaic_size: side of the stitched mosaic in pixels (default 1536, the
            value previously hard-coded).
        border: pixels removed from each side of the mosaic (default 18, as
            previously hard-coded; presumably padding added before tiling --
            TODO confirm).

    Raises:
        FileNotFoundError: if a tile cannot be read.
    """
    # The module flag ``pinjie = True`` is rebound to this very function by
    # its ``def``, so the former ``if pinjie:`` guard was always true; it is
    # dropped without changing behaviour.
    # NOTE(review): os.listdir order is arbitrary; pairing tile groups with
    # these names assumes the tiles were produced in this same order.
    gtList = os.listdir(gt_path)
    if use_overlap:
        num_pin = mosaic_size // h * 2 - 1  # tiles per row/column
    else:
        num_pin = mosaic_size // h  # tiles per row/column
    tiles_per_image = num_pin * num_pin
    for l, name in enumerate(tqdm(gtList)):
        first = l * tiles_per_image  # tiles are numbered consecutively from 0
        rows = []
        for j in range(num_pin):  # j = grid row
            pieces = []
            for k in range(num_pin):  # k = grid column
                path = predict_path + str(first + j * num_pin + k) + ".tif"
                tile = cv2.imread(path, 0)  # 0 = read as grayscale
                if tile is None:  # cv2.imread reports failure by returning None
                    raise FileNotFoundError('cannot read prediction tile: ' + path)
                if use_overlap:
                    tile = crop_overlap_tile(tile, j, k, num_pin, h, w)
                pieces.append(tile)
            rows.append(np.hstack(pieces))  # horizontal concatenation
        img_out = np.vstack(rows)  # vertical concatenation
        # Crop the stitched mosaic.
        img_out = img_out[border:mosaic_size - border, border:mosaic_size - border]
        cv2.imwrite(pathout + name, img_out)  # write the output image
if __name__ == '__main__':
    # Stage 1 predicts a mask for every tile; stage 2 stitches the tile
    # masks back into full-size images.
    for stage in (yuce, pinjie):
        stage()
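# bce_predict阈值分割.py  (file name: "bce_predict threshold segmentation")
# 最新推荐文章于 2023-04-18 16:50:11 发布  (web-page residue: "latest recommended article published 2023-04-18 16:50:11")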