Python matplotlib画损失函数和精确度对比图并在原图上局部放大

1、调用库

从训练过程中导出训练集和验证集的Loss 和 Acc数据,使用CVS格式存储。因此,读取CVS数据过程中使用Pandas库。另外,需要将读取的数据转化成.npy格式,因此需要用到Numpy库。画图需要使用matplotlib

1.2、cvs数据展示

 2、局部放大函数编写

# load model
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams

def zoom_enlarge_1(axes, enlarge_position, x1, x2, y1, y2):
    xmin, xmax = x1, x2  # define the coordinates of the magnification section
    ymin, ymax = y1, y2  

    ax_zoom = plt.axes(enlarge_position)  # enlarge position means the subfigure position in the figure
    plt.plot(d2[:,0],d2[:,2], color='tomato', linewidth =2)
    plt.plot(d2[:,0],d3[:,2], color='deepskyblue', linewidth =2)
    plt.plot(d2[:,0],d4[:,2], color='springgreen', linewidth =2)
    plt.plot(d2[:,0],d5[:,2], color='deeppink', linewidth =2) # above plot the selected curves
    ax_zoom.set_xlim([xmin, xmax])  
    ax_zoom.set_ylim([ymin, ymax])  # above set the subfigure axis
#     ax_zoom.set_xlabel('X')  
#     ax_zoom.set_ylabel('Y')  
#     ax_zoom.set_title('Zoomed In')  

    rect = plt.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, fill=False, edgecolor='r', linewidth=2)  
    axes.add_patch(rect)  # show the selected part of the figure with highted line in red color
    
    
def zoom_enlarge_2(axes, enlarge_position, x1, x2, y1, y2):
    xmin, xmax = x1, x2  
    ymin, ymax = y1, y2  

    ax_zoom = plt.axes(enlarge_position)  
    plt.plot(d2[:,0],d2[:,1], color='tomato', linewidth =2)
    plt.plot(d2[:,0],d3[:,1], color='deepskyblue', linewidth =2)
    plt.plot(d2[:,0],d4[:,1], color='springgreen', linewidth =2)
    plt.plot(d2[:,0],d5[:,1], color='deeppink', linewidth =2)
    ax_zoom.set_xlim([xmin, xmax])  
    ax_zoom.set_ylim([ymin, ymax])  
#     ax_zoom.set_xlabel('X')  
#     ax_zoom.set_ylabel('Y')  
#     ax_zoom.set_title('Zoomed In')  

    rect = plt.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, fill=False, edgecolor='r', linewidth=2)  
    axes.add_patch(rect)  

 3、画原始曲线图

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
# plt.rcParams['font.sans-serif'] = 'simhei'
# plt.rcParams['axes.unicode_minus']=False
 
d2 = pd.read_csv("./logs/multi_scale_block_2.csv")
d2 = np.array(d2)
print(d2.shape)

d3 = pd.read_csv("./logs/multi_scale_block_3.csv")
d3 = np.array(d3)
print(d3.shape)

d4 = pd.read_csv("./logs/multi_scale_block_4.csv")
d4 = np.array(d4)
print(d4.shape)

d5 = pd.read_csv("./logs/multi_scale_block_5.csv")
d5 = np.array(d5)
print(d5.shape) # above load curves with pandas and convert data to numpy format


config = {
    "font.family":'serif',
    "font.size": 18,
    "mathtext.fontset":'stix',
    "font.serif": ['Times new roman'],
}
rcParams.update(config) # set the text style



plt.figure(figsize=(14,6))
ax1 = plt.subplot(121)
plt.plot(d2[:,0],d2[:,2], color='tomato', linewidth =2)
plt.plot(d2[:,0],d3[:,2], color='deepskyblue', linewidth =2)
plt.plot(d2[:,0],d4[:,2], color='springgreen', linewidth =2)
plt.plot(d2[:,0],d5[:,2], color='deeppink', linewidth =2)# above show curves
plt.title('Train loss',fontsize=20, fontweight='bold')
plt.ylabel('Loss',fontsize=18)
plt.xlabel('Epoch',fontsize=18)
plt.legend(['block = 2', 'block = 3', 'block = 4', 'block = 5'], loc='upper right',fontsize=18)
plt.tick_params(axis='both', which='major', labelsize=18)
plt.tick_params(axis='both', which='minor', labelsize=18)
plt.yticks(fontproperties='Times New Roman', size=18)
plt.xticks(fontproperties='Times New Roman', size=18)
zoom_enlarge_1(ax1, (0.25, 0.3, 0.2, 0.2), 3, 20, 0.00, 0.015) # use the enlarge function

bwith = 1.5 # define the with of the figure borders
TK = plt.gca()
TK.spines['bottom'].set_linewidth(bwith)
TK.spines['left'].set_linewidth(bwith)
TK.spines['top'].set_linewidth(bwith)
TK.spines['right'].set_linewidth(bwith)

# summarize history for loss
# plt.figure(figsize=(7,6))
ax2 = plt.subplot(122)
plt.plot(d2[:,0],d2[:,1], color='tomato', linewidth =2)
plt.plot(d2[:,0],d3[:,1], color='deepskyblue', linewidth =2)
plt.plot(d2[:,0],d4[:,1], color='springgreen', linewidth =2)
plt.plot(d2[:,0],d5[:,1], color='deeppink', linewidth =2)
plt.title('Train acc',fontsize=20, fontweight='bold')
plt.ylabel('Acc',fontsize=18)
plt.xlabel('Epoch',fontsize=18)
plt.legend(['block = 2', 'block = 3', 'block = 4', 'block = 5'], loc='lower right',fontsize=18)
plt.tick_params(axis='both', which='major', labelsize=18)
plt.tick_params(axis='both', which='minor', labelsize=18)
plt.yticks(fontproperties='Times New Roman', size=18)
plt.xticks(fontproperties='Times New Roman', size=18)
zoom_enlarge_2(ax2, (0.65, 0.5, 0.2, 0.2), 3, 20, 0.95, 1.00)


bwith = 1.5 
TK = plt.gca()
TK.spines['bottom'].set_linewidth(bwith)
TK.spines['left'].set_linewidth(bwith)
TK.spines['top'].set_linewidth(bwith)
TK.spines['right'].set_linewidth(bwith)

plt.savefig('loss_different_block.png', dpi=600, bbox_inches='tight')# save figure
plt.savefig('loss_different_block.eps', dpi=600, bbox_inches='tight')
plt.show()

3、效果图

 

  • 8
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

某崔同学

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值