预测代码:
from x1 import UNet
import torch
import numpy as np
import cv2
# Standardization (z-score)
def standardize(data):
    """Return ``data`` shifted to zero mean and scaled to unit standard deviation.

    The mean and (population, ``ddof=0``) standard deviation are computed over
    all elements of the array, not per channel or per axis.

    Args:
        data: array-like of numbers (anything ``np.asarray`` accepts).

    Returns:
        An ndarray with the same shape as ``data``. Empty input yields an empty
        float array. If ``data`` has zero standard deviation (e.g. a constant
        image), the centred array (all zeros) is returned instead of NaNs.
    """
    data = np.asarray(data)
    if data.size == 0:
        # Mean/std of an empty array are undefined; nothing to scale.
        return data.astype(np.float64)
    mean = data.mean()
    std = data.std()
    # A constant input has std == 0; dividing by it would turn every element
    # into NaN. The centred data is already all zeros, so divide by 1 instead.
    scale = std if std != 0 else 1
    return (data - mean) / scale
# Min-max normalization
def normalize(data):
    """Linearly rescale ``data`` so its minimum maps to 0 and its maximum to 1.

    Args:
        data: array-like of numbers (anything ``np.asarray`` accepts).

    Returns:
        An ndarray with the same shape as ``data`` and values in [0, 1].
        Empty input yields an empty array. If all elements are equal
        (max == min) an all-zero array is returned instead of NaNs.
    """
    data = np.asarray(data)
    if data.dtype.kind in "biu":
        # Bool / signed / unsigned integer input: compute in float64 so that
        # `max - min` and `data - min` cannot wrap around in a narrow integer
        # type (e.g. int8 127 - (-128)), and so bool subtraction is allowed.
        data = data.astype(np.float64)
    if data.size == 0:
        # min()/max() raise on empty arrays; there is nothing to rescale.
        return data.copy()
    min_val = data.min()
    max_val = data.max()
    value_range = max_val - min_val
    # A constant input would divide 0 by 0 (NaN); the shifted data is already
    # all zeros in that case, so dividing by 1 leaves it unchanged.
    return (data - min_val) / (value_range if value_range != 0 else 1)
# --- Inference script ----------------------------------------------------
# Pixels whose min-max-rescaled prediction exceeds this value are kept.
MASK_THRESHOLD = 0.95
# Binary threshold (0-255 scale) applied to the masked 8-bit image.
BINARY_THRESHOLD = 175

# Use the first GPU when available; fall back to CPU instead of crashing.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Instantiate the network and restore its trained weights.
model = UNet().to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer torch versions also accept weights_only=True).
model.load_state_dict(
    torch.load('saved_model/unet_course_best.pt', map_location=device)
)
model.eval()

img_file = "dui/img/f_y022f_147.jpg.npy"
img_data_crop = np.load(img_file)

# Z-score standardize, then min-max rescale to [0, 1].
# NOTE(review): min-max rescaling is invariant to the positive affine map of
# z-scoring, so the first step is mathematically redundant; presumably kept to
# mirror the training preprocessing — confirm before removing.
std = standardize(img_data_crop)
normalized = normalize(std)

# Prepend batch and channel axes -> (1, 1, *normalized.shape), matching the
# original torch.tensor([[normalized]]) but without the slow list path.
input_tensor = torch.from_numpy(normalized)[None, None].to(
    device, dtype=torch.float32
)

# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    y_pred = model(input_tensor)

# Take the first channel of the first batch item and rescale it to [0, 1].
img_ = normalize(y_pred.cpu().numpy()[0][0])
print(img_.max(), img_.min())
img_1 = (img_ * 255.0).astype(np.uint8)
print(img_)
# Keep only pixels above MASK_THRESHOLD; all others become 0.
mask_data = img_ > MASK_THRESHOLD
img = img_1 * mask_data
# Binarize to 0/255 for display; `mask` is consumed by the display code below.
ret, mask = cv2.threshold(img, BINARY_THRESHOLD, 255, cv2.THRESH_BINARY)
def cv_show(neme, img):
    """Show ``img`` in an OpenCV window titled ``neme`` until a key is pressed.

    Every OpenCV HighGUI window is closed once the key press arrives.
    """
    title, frame = neme, img
    cv2.imshow(title, frame)
    cv2.waitKey(0)  # a delay of 0 blocks until any key is pressed
    cv2.destroyAllWindows()


# Display the binary mask computed above.
cv_show('neme', mask)