# 代码 (Code)
import matplotlib.pyplot as plt
import numpy as np
def get_data(txt_path: str = '', epoch: int = 100, target: str = '', target_data_len: int = 5) -> list:
    """Extract the numbers that follow successive occurrences of *target* in a text file.

    The file is scanned left to right. For each of the first *epoch*
    occurrences of *target*, the *target_data_len* characters immediately
    after it are parsed with ``float()``.

    Args:
        txt_path: Path of the UTF-8 encoded text file to read.
        epoch: Number of occurrences (values) to read.
        target: Non-empty marker string that precedes each value, e.g. ``"ACC1:"``.
        target_data_len: Fixed width, in characters, of the value window after
            the marker. ``float()`` tolerates surrounding whitespace, but any
            other extra character inside this window makes parsing fail.

    Returns:
        A list of floats (``epoch`` of them when ``epoch > 0``), in file order.

    Raises:
        ValueError: If *target* is empty, occurs fewer than *epoch* times, or
            a value window cannot be parsed as a float.
        OSError: If the file cannot be opened or read.
    """
    if not target:
        # An empty marker matches everywhere and would yield meaningless values.
        raise ValueError("target must be a non-empty string")

    # `with` guarantees the file is closed even if reading fails.
    with open(txt_path, encoding="utf-8") as f:
        text = f.read()

    num_list = []
    pos = 0  # search offset; occurrences before it were already consumed
    for found in range(epoch):
        index = text.find(target, pos)
        if index == -1:
            # Without this check, slicing from index -1 would silently parse
            # unrelated text near the start of the file.
            raise ValueError(
                f"{target!r} found only {found} time(s) in {txt_path!r}, "
                f"expected {epoch}"
            )
        start = index + len(target)
        num_list.append(float(text[start:start + target_data_len]))
        # Continue searching right after this marker (linear overall, instead
        # of rewriting the whole string on every iteration).
        pos = start
    return num_list
# Enlarge every text element (ticks, labels, legend) of the figure drawn below.
plt.rcParams.update({'font.size': 18})

# All four series come from the same log file; each get_data call re-reads it
# and parses the 11-character value that follows the given marker.
LOG_PATH = "./everything_to_Matlab/test.txt"

# Accuracy series: 51 values each.
# NOTE(review): ACC2 is read from the "test2:" marker, unlike "ACC1:" for
# ACC1 — confirm this matches the log format.
list_ACC1 = get_data(LOG_PATH, 51, target="ACC1:", target_data_len=11)
list_ACC2 = get_data(LOG_PATH, 51, target="test2:", target_data_len=11)

# Loss series: 50 values each.
list_loss1 = get_data(LOG_PATH, 50, target="loss1:", target_data_len=11)
list_loss2 = get_data(LOG_PATH, 50, target="loss2:", target_data_len=11)
# Left y-axis: the two accuracy curves, labelled lr_mul=1 and lr_mul=0.5.
fig, ax1 = plt.subplots()
acc_curves = (
    (list_ACC1, "#E18E6D", "lr_mul=1"),
    (list_ACC2, "#62B197", "lr_mul=0.5"),
)
for curve, colour, tag in acc_curves:
    ax1.plot(curve, color=colour, label=tag)
ax1.legend(loc='center right')
# The top tick (1.006, equal to the upper y-limit) is labelled "Accuracy",
# so that tick label doubles as the axis title. Ticks are set before the
# limits, as in the original ordering.
ax1.set_yticks([0.9995, 0.9943, 1.006])
ax1.set_yticklabels(["99.95%", "99.43%", "Accuracy"])
ax1.set_ylim(0.90, 1.006)
ax1.set_xlim(0, 50)
ax1.set_xlabel("epoch")
ax1.grid(axis='y')

# Right y-axis (twin of ax1): the two loss curves in the matching colours.
ax2 = ax1.twinx()
for curve, colour in ((list_loss1, "#E18E6D"), (list_loss2, "#62B197")):
    ax2.plot(curve, color=colour)
# Tick positions are raw loss values; labels are written in units of 1e-3,
# and the top tick (equal to the upper y-limit) carries the title "loss(e-3)".
ax2.set_yticks([0.0005025579, 0.0001039364, 0.0079685581])
ax2.set_yticklabels(["0.5", "0.1", "loss(e-3)"])
ax2.set_ylim(0.0001039364, 0.0079685581)
# NOTE(review): twinx() shares the x-axis with ax1, so these two calls look
# redundant with ax1's settings — confirm before removing.
ax2.set_xlim(0, 50)
ax2.set_xlabel("epoch")
ax2.grid(axis='y')
plt.show()
# 结果 (Result): the figure displayed by plt.show() above.