深度模型测试

import torch,gc
import sys
from torch.utils.data import DataLoader
from datasets_ import Drishti_IMG,Test_Drishti
from models_.net import UNet
from torchvision.utils import save_image
import os
from models_.myNet import MyNet
from Metrics import IOU_acc,Dice
from utils_.Util_methods import  AddInterpreter


""" 四个网络 """


if __name__ == '__main__':

    # 设置数据集
    data_root=os.path.join("/home/wzc/zlt_self/pOSAL-master","data")
    if os.path.exists(data_root):
        dataloader=DataLoader(dataset=Test_Drishti(data_root),batch_size=4,shuffle=False)
    else:
        print("{} 文件夹不存在".format(data_root))
        
    # model加载到GPU,然后加载权重参数,最后设置优化和损失函数
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_=UNet().to(device=device)
    for index_path in range(4):
        weight_path=f'/home/wzc/zlt_self/pOSAL-master/LarryChow/draft/tempFiles/weights_/unet_{index_path}.pth'
        # weight_path='/home/wzc/zlt_self/pOSAL-master/LarryChow/draft/tempFiles/weights_/unet.pth'
        if os.path.exists(weight_path):
            model_.load_state_dict(torch.load(weight_path))
            print('successful load weight!')
            
        else:
            print("not successful load weight")
            sys.exit(0)
        
        # 设置损失函数和精确度
        dice_fun=Dice()
        acc_fun=IOU_acc()
        
        for i,(img_,label_) in enumerate(dataloader):
            if i==1:
                continue
            #释放内存
            gc.collect()
            torch.cuda.empty_cache()
            # 模型预测
            img_=img_.to(device)
            label_=label_.to(device)
            pred_=model_(img_)  #type(pred_)=<class 'torch.Tensor'>
            for x_ in range(4):
                for index_ in range(1):
                    # save_img_=torch.stack([img_[x_],label_[x_],pred_[x_]],dim=0)
                    save_image(pred_[x_],f"unet_{index_path}-{x_}.png")
                    #添加文字信息
                    # AddInterpreter(imgPath=f"/home/wzc/zlt_self/pOSAL-master/tempFiles/unet-test_result{i}-{x_}.png",savePath=f"/home/wzc/zlt_self/pOSAL-master/tempFiles/unet-test_result_{i}-{x_}.png")
                    # print("pred success...")
                    # # 损失函数和精确度
                    # test_dice=dice_fun(label_,pred_)
                    # test_acc=acc_fun(label_,pred_)
                    # print(f"***{i}-{x_}test_dice:{test_dice}   test_acc:{test_acc}\n")
                    print(f"{index_path}{i}-{x_}-{index_}OK")
            break
            

  • 6
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Larry Chow

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值