# Predict(生成图像): run the trained model on the test set and write the predicted label images.
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from PIL import Image
from major_dataset import LoadDataset
import cv2
import major_config
# Build the test dataset from the configured image/label locations and iterate
# it one sample at a time, so every prediction maps to exactly one output file.
Load_test = LoadDataset([major_config.test_image, major_config.test_label], major_config.crop_size)
test_data = DataLoader(Load_test, batch_size=1)

# Put the configured model in evaluation mode on the target device, then
# restore the trained weights.
net = major_config.model
net.eval()
net.to(major_config.device)
# map_location lets a checkpoint saved on another device (e.g. a GPU) be
# loaded when the configured device differs (e.g. a CPU-only machine).
net.load_state_dict(
    torch.load(major_config.path_predict_model, map_location=major_config.device)
)

# NOTE(review): this table is loaded but never used below; presumably it was
# meant for mapping class ids to colors in the saved image — confirm.
color2class_table = pd.read_csv(major_config.path_color2class_table)

# Gradients are not needed for prediction; disabling them avoids building the
# autograd graph and lets the result be converted to NumPy directly.
with torch.no_grad():
    for i, sample in enumerate(test_data):
        valImg = sample['img'].to(major_config.device)
        out = net(valImg)
        # assumes dim 1 of the model output is the class axis — TODO confirm
        out = F.log_softmax(out, dim=1)
        # Pick the highest-scoring class along dim 1, then drop only the batch
        # axis (batch_size=1) so a size-1 spatial dimension is not squeezed away.
        pre_label = out.argmax(dim=1).squeeze(0).cpu().numpy()
        print(pre_label)

        # argmax yields int64, which the PNG encoder cannot write; cast to an
        # unsigned type wide enough for the class ids that are present.
        png_dtype = np.uint8 if pre_label.max() <= np.iinfo(np.uint8).max else np.uint16
        out_path = str(i) + ".png"
        if not cv2.imwrite(out_path, pre_label.astype(png_dtype)):
            raise RuntimeError("failed to write prediction image: " + out_path)

        # Re-open the saved file and display it with the default image viewer.
        img_show = Image.open(out_path)
        img_show.show()