import os
import torch as t
import torch.nn as nn
from torchvision import transforms
import cv2
from torchvision.transforms.functional import to_tensor
from unet import Unet
import numpy as np
import torch.nn.functional as F
from fcn import FCNet1
def load_weights(model, checkpoint):
    """Load checkpoint weights into ``model`` and return the model to use.

    ``checkpoint`` may be a raw state dict or a dict that holds one under the
    ``'state_dict'`` key. If the parameter names start with the ``'module.'``
    prefix that ``nn.DataParallel`` adds when saving, ``model`` is wrapped in
    ``nn.DataParallel`` (unless it already is) so the names match.

    Returns:
        The model with weights loaded. This may be a new ``DataParallel``
        wrapper around the original ``model``.
    """
    state_dict = checkpoint
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    # Check the prefix, not a substring: a key such as "submodule.x" must not
    # trigger wrapping. Default to "" so an empty state dict doesn't raise here.
    first_key = next(iter(state_dict), "")
    if first_key.startswith("module.") and not isinstance(model, t.nn.DataParallel):
        model = t.nn.DataParallel(model)
    model.load_state_dict(state_dict)
    return model


def image_to_input(img):
    """Convert an HWC image array (as from ``cv2.imread``) to a float32 batch.

    The channel order is left untouched (OpenCV loads BGR) and no value
    scaling is applied.
    NOTE(review): confirm this matches the preprocessing used at training time.

    Returns:
        A float32 tensor of shape (1, C, H, W).
    """
    chw = np.transpose(img, (2, 0, 1)).astype(np.float32)
    return t.from_numpy(chw).unsqueeze(0)


def output_to_image(output):
    """Turn a (N, C, H, W) network output into an 8-bit single-channel image.

    Takes channel 0 of the first batch item, scales it by 255, then rounds and
    clips to [0, 255]. This makes explicit the saturation that ``cv2.imwrite``
    would otherwise apply implicitly to a float array.
    NOTE(review): whether channel 0 holds probabilities or raw logits depends
    on the model's final layer, which is not visible here.
    """
    channel0 = output[0, 0].detach().cpu().numpy() * 255
    return np.clip(np.rint(channel0), 0, 255).astype(np.uint8)


def predict(pretreix, filelist, checkpoint_path="./checkpoints/weights_19.pth",
            out_dir="./testImg/"):
    """Run the model on each listed image and save its channel-0 output map.

    Args:
        pretreix: Path prefix, concatenated directly with each file name (so a
            directory must end with a separator, e.g. ``"data/"``).
        filelist: File names to process. Entries that OpenCV cannot read as
            images are skipped with a message.
        checkpoint_path: Path of the saved weights to load.
        out_dir: Directory that receives one output image per input, under the
            same file name. It is created if missing.
    """
    # Alternative architecture used previously: Unet(3, 2).
    model = FCNet1(3, 2)
    device = t.device("cuda" if t.cuda.is_available() else "cpu")
    checkpoint = t.load(checkpoint_path, map_location=device)
    model = load_weights(model, checkpoint)
    model.to(device)
    model.eval()

    # cv2.imwrite does not create directories; it just returns False.
    os.makedirs(out_dir, exist_ok=True)
    with t.no_grad():
        for filename in filelist:
            img = cv2.imread(pretreix + filename)
            if img is None:
                # Non-image or unreadable entry (os.listdir returns everything).
                print(f"skipping unreadable image: {filename}")
                continue
            output = model(image_to_input(img).to(device))
            out_path = os.path.join(out_dir, filename)
            if not cv2.imwrite(out_path, output_to_image(output)):
                print(f"failed to write {out_path}")
    print("hello")
if __name__ == "__main__":
    # Script entry point: run inference on every file in the validation folder.
    test_path = "D:/jks_p80/p80White/p80White/sample/defect/val/1/"
    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    listImg = sorted(os.listdir(test_path))
    predict(test_path, listImg)