图片读取
from PIL import Image
img_org = Image.open(图片path).convert('RGB')
import cv2
label_mask = cv2.imread(图片path) # 默认以BGR形式读取图片
label_mask = cv2.cvtColor(label_mask, cv2.COLOR_BGR2GRAY) # 图片转成灰度图
图片变换
在PIL格式下才能调用torchvision.transforms进行图像变换
from torchvision import transforms
to_pil = transforms.ToPILImage()
prediction = to_pil(prediction)
图片保存
字体设置:中文
# 默认设置,后期画图可在画图函数内修改
config = {
"font.family": "serif", # 使用衬线体
# "font.serif": ["SIMSUN"], # 全局默认使用衬线宋体
"font.size": 12, # 五号,10.5磅
"axes.unicode_minus": False,
"mathtext.fontset": "stix", # 设置 LaTeX 字体,stix 近似于 Times 字体
}
rcParams.update(config)
SimSun = FontProperties(fname='/home/ges/anaconda3/envs/gesnew/lib/python3.8/site-packages/matplotlib/mpl-data/fonts/ttf/SIMSUN.ttc') # 字体库文件
画图及保存方法
#cv2 保存图片
cv2.imwrite(图片path, numpy类型的数组) # numpy类型的数组像素值的范围为[0, 255],必须是三通道图,分别为BGR通道
# 如果不是三通道,则将通道复制成三通道后保存:
plt.imsave(os.path.join(prediction_save_path, img_name + '_prediction_mean.png'), np.stack((prediction_array_mean, prediction_array_mean, prediction_array_mean), 2))
# plt 画多个子图并保存方法1
plt.figure()
plt.subplot(2,2,1)
plt.bar(data1, data1, width=0.05, color ='red', label="Calibration gap", hatch = '/', edgecolor='red', alpha = 0.3)
plt.bar(data, mean_bin_acc_list, width=0.05, color ='blue', label="outputs", alpha = 0.7)
plt.text(0.2,0.4,'ECE: '+'%.4f' %ECE1, fontsize=16, verticalalignment="top",horizontalalignment = "right", fontproperties=Times)
# plt.imshow(image_show_org)
plt.axis('off')
plt.legend(loc='upper left')
plt.xticks(fontproperties=Times, fontsize=14)
plt.yticks(fontproperties=Times, fontsize=14)
plt.xlabel("置信度", fontproperties=SimSun, fontsize=20)
plt.ylabel("准确度", fontproperties=SimSun, fontsize=20)
plt.title("单次预测", fontproperties=SimSun, fontsize=20)
plt.tight_layout()
plt.show()
plt.savefig(路径),dpi=500)
# plt 画多个子图方法2
fig, axes = plt.subplots(nrows=1, ncols=3, constrained_layout=True,sharey=True,figsize=(15,5))
axes0 = axes[0]
axes0.plot(data_new, accuracy_list_MI)
axes0.set_title("互信息", fontproperties=SimSun,fontsize=24)
axes0.tick_params(labelsize=20) # 设置横纵坐标刻度字体大小
axes0.grid(True) #显示网格
......
fig.tight_layout()
plt.show()
plt.savefig(referral_acc_path, dpi=500)