# bce_predict阈值分割.py — BCE-model prediction with threshold segmentation (阈值分割)

from __future__ import print_function
import torch
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import time
import cv2
# from models.Eca_ASP_v6 import eca_ASP_v6
# from models.res import res
from models.res_eca import res_eca
# from models.Eca_ASP_1x import eca_ASP_1x
# from models.res_eca_v3head import res_eca_v3head
# from models.res_eca_v3plushead import res_eca_v3plushead
# from models.res_wudecorder import res_wudecorder
# from models.res_eca_v3head import res_eca_v3head

#---------------------------------------------------------------#
# Paths and switches shared by yuce() (per-tile inference) and pinjie() (stitching).
val_path = './test/384/val/'          # input tiles to run the model on
predict_path = './test/pred/'         # yuce() writes one mask per input tile here
valList = os.listdir(val_path)        # NOTE: read at import time; order is filesystem-dependent
num = len(valList)                    # how many tiles yuce() processes
yuzhi = 0.5                           # binarization threshold (author's note: 0.4 scored highest)
# pathlabel = './test/pred/'
gt_path = './test/target/'            # pinjie() uses these file names (and their count) to name outputs
pathout = './test/pred_he_acc_0.5/'   # pinjie() writes stitched masks here
use_overlap = False                   # True: stitch assuming tiles overlap by half a tile
# NOTE(review): this flag has no effect — `def pinjie()` further down rebinds the
# same module-level name, so any later `if pinjie:` sees the (always truthy) function.
pinjie = True
h = w = 384                           # tile height and width in pixels
def yuce():
    """Predict a binary mask for every tile listed in ``valList``.

    Loads ``res_eca(layers=101, classes=1, use_aux=True)`` weights from
    ``model_path`` onto the GPU. Then, for each of the first ``num`` file names
    in ``valList``, it:

    * opens ``val_path + name`` as RGB;
    * normalizes it with the per-channel mean/std below;
    * runs the model;
    * thresholds the first output at ``yuzhi`` (>= threshold -> 1, else 0);
    * saves the mask with ``ToPILImage`` to ``predict_path + name``.

    It also creates ``predict_path`` and ``pathout`` if missing, and prints the
    total elapsed time.
    """
    # Alternative models/checkpoints (swap in to evaluate another network):
    # model = eca_ASP_v6(layers=50,  classes=1, pretrained=False,use_aux=True)
    # model = res(layers=50, classes=1, pretrained=False, use_aux=False)
    # model_path = './checkpoint/res/model/netG_best_acc.pth
    # model_path = './checkpoint/Eca_ASPx2_aux_wubn/model/netG_best_acc.pth'
    # model = eca_ASP_1x(layers=50, classes=1, pretrained=False, use_aux=False)
    # model_path = './checkpoint/eca_ASP_1x/model/netG_best_acc.pth'
    model = res_eca(layers=101, classes=1, pretrained=False, use_aux=True)
    model_path = './checkpoint/101/netG_best_acc.pth'
    model.load_state_dict(torch.load(model_path, map_location='cuda'))
    # To load a checkpoint whose keys only partly match the model, replace the
    # load_state_dict line above with:
    #   checkpoint_finetune = torch.load(model_path, map_location='cuda')
    #   model_dict = model.state_dict()
    #   model_dict.update({k: v for k, v in checkpoint_finetune.items() if k in model_dict})
    #   model.load_state_dict(model_dict)
    model = model.cuda()
    model.eval()

    # exist_ok only tolerates "already exists"; genuine failures (permissions,
    # a regular file in the way) now surface instead of being silently swallowed.
    os.makedirs(predict_path, exist_ok=True)
    os.makedirs(pathout, exist_ok=True)
    #---------------------------------------------------------------#

    transform1 = transforms.Compose([
        transforms.ToTensor(),
        # Per-channel dataset statistics: this standardizes each channel to
        # roughly zero mean / unit variance (it does NOT map to [-1, 1]).
        transforms.Normalize(mean=[0.31701732, 0.32337377, 0.28925751],
                             std=[0.17323045, 0.16700189, 0.16922423]),
    ])
    to_pil = transforms.ToPILImage()

    start = time.time()
    # Pure inference: no_grad avoids building an autograd graph for every tile.
    with torch.no_grad():
        for name in tqdm(valList[:num]):
            with Image.open(val_path + name) as src:
                test_image = src.convert('RGB')
            img = transform1(test_image).unsqueeze(0).cuda()

            # With use_aux=True the model returns two outputs; only the first is
            # used (presumably the main head's probability map — confirm).
            label_image, _aux = model(img)
            label_image = label_image.squeeze(0)

            # Binarize in a single out-of-place step. The former pair of in-place
            # masks (set >= t to 1, then set < t to 0) zeroed every foreground
            # pixel whenever t > 1, because the freshly written 1s then matched "< t".
            mask = (label_image >= yuzhi).to(label_image.dtype).cpu()
            to_pil(mask).save(predict_path + name)

    end = time.time()
    print('Program processed ', end - start, 's, ', (end - start) / 60, 'min ')

def _read_tile(path):
    """Load one predicted tile from ``path`` as a single-channel grayscale array."""
    tile = cv2.imread(path, 0)  # flag 0 == cv2.IMREAD_GRAYSCALE
    if tile is None:
        # cv2.imread reports a missing/unreadable file by returning None; fail
        # here with the path instead of crashing later inside np.hstack/slicing.
        raise FileNotFoundError('could not read predicted tile: {}'.format(path))
    return tile


def _crop_overlap_tile(tile, j, k, num_pin, h, w):
    """Trim overlap margins from tile (row ``j``, column ``k``) of a ``num_pin`` x ``num_pin`` grid.

    Every edge shared with a neighbour drops a quarter tile (``h // 4`` rows or
    ``w // 4`` columns), so adjacent tiles cut with half-tile stride abut
    exactly. Edges on the mosaic border are kept whole. Computing each edge
    independently fixes the former case table. That table cropped the
    bottom-left tile as ``[h//4:h, w//4:w]``, dropping its left quarter and
    shifting the whole bottom row; it is now ``[h//4:h, 0:w-w//4]``.
    """
    last = num_pin - 1
    top = 0 if j == 0 else h // 4
    bottom = h if j == last else h - h // 4
    left = 0 if k == 0 else w // 4
    right = w if k == last else w - w // 4
    return tile[top:bottom, left:right]


def pinjie(enabled=True):
    """Stitch the per-tile masks in ``predict_path`` back into full-size images.

    Tiles are read in numeric order (``0.tif``, ``1.tif``, ...). Output image
    ``n`` consumes ``num_pin * num_pin`` consecutive tiles, laid out row by row.
    When ``use_overlap`` is set, the overlapping margins are trimmed first. The
    1536 x 1536 mosaic is cropped to ``[18:1518, 18:1518]`` and written to
    ``pathout`` under the n-th name from ``os.listdir(gt_path)``.

    Args:
        enabled: pass False to skip stitching. This replaces the module-level
            ``pinjie = True`` flag, which is shadowed by this function's own
            name, so the old ``if pinjie:`` check was always true.

    Raises:
        FileNotFoundError: a predicted tile is missing or unreadable.
        OSError: an output image could not be written.
    """
    if not enabled:
        return

    mosaic_size = 1536  # side length of the stitched mosaic, in pixels
    border = 18         # margin trimmed from each side after stitching

    gtList = os.listdir(gt_path)
    # NOTE(review): os.listdir order is filesystem-dependent; confirm it matches
    # the order in which the source images were tiled and numbered.
    if use_overlap:
        num_pin = mosaic_size // h * 2 - 1  # tiles per side at half-tile stride
    else:
        num_pin = mosaic_size // h          # tiles per side, no overlap
    tiles_per_image = num_pin * num_pin

    for img_idx, out_name in enumerate(tqdm(gtList)):
        rows = []
        for j in range(num_pin):          # j: tile row
            row_tiles = []
            for k in range(num_pin):      # k: tile column
                tile_idx = img_idx * tiles_per_image + j * num_pin + k
                tile = _read_tile(predict_path + str(tile_idx) + ".tif")
                if use_overlap:
                    tile = _crop_overlap_tile(tile, j, k, num_pin, h, w)
                row_tiles.append(tile)
            rows.append(np.hstack(row_tiles))  # join one row horizontally
        img_out = np.vstack(rows)              # stack rows vertically
        img_out = img_out[border:mosaic_size - border, border:mosaic_size - border]
        out_path = pathout + out_name
        if not cv2.imwrite(out_path, img_out):
            raise OSError('could not write stitched image: {}'.format(out_path))

if __name__ == '__main__':
    # Stage 1 predicts a mask per tile; stage 2 stitches tiles into full images.
    for stage in (yuce, pinjie):
        stage()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值