2020-09-21

import os
import torch as t
import torch.nn as nn
from torchvision import transforms
import cv2
from torchvision.transforms.functional import to_tensor

from unet import Unet
import numpy as np
import torch.nn.functional as F
from fcn import FCNet1

def predict(pretreix, filelist):
#model = Unet(3, 2)
model = FCNet1(3, 2)
device = t.device(“cuda” if t.cuda.is_available() else “cpu”)
checkpoint = t.load("./checkpoints/weights_19.pth", map_location=device)
if isinstance(checkpoint, dict) and ‘state_dict’ in checkpoint.keys():
checkpoint = checkpoint[‘state_dict’]
if ‘module’ in list(checkpoint.keys())[0] and not isinstance(model, t.nn.DataParallel):
model = t.nn.DataParallel(model)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()

with t.no_grad():
    for filename in filelist:
        img = cv2.imread(pretreix + filename)
        img = np.transpose(img, (2, 0, 1))
        img = np.float32(img)
        # img = np.transpose(img, (0, 1, 2))
        img = t.from_numpy(img)
        img = img.unsqueeze(0)

        # input = img.unsqueeeze(0)
        output = model(img.to(device))

        outarray = output.cpu().numpy()
        output1 = np.transpose(outarray, (0, 2, 3, 1))
        imgout = output1[0, :, :, 0] * 255
        cv2.imwrite("./testImg/" + filename, imgout)
        print("hello")

if name == “main”:
test_path = “D:/jks_p80/p80White/p80White/sample/defect/val/1/”
listImg = os.listdir(test_path)
predict(test_path, listImg)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值