def _plot_best_curve_per_sigma(ax, methods, sigmas, colors, label, max_radius, radius_step, **plot_kwargs):
    """Plot one certified-accuracy curve per sigma onto ``ax``.

    For each sigma, only the methods whose ``quantity.data_file_path`` contains
    the sigma formatted with two decimals (e.g. ``"0.25"``) are considered.
    Among those, the row with the highest value in the first column returned by
    ``_get_accuracies_at_radii`` (the start radius, 0) is plotted.

    Args:
        ax: matplotlib Axes to draw on.
        methods: iterable of ``Line`` objects to choose from.
        sigmas: noise levels to plot, one curve each.
        colors: color cycle; the color for a sigma depends only on its position
            in ``sigmas``, so solid and dashed curves of the same sigma match.
        label: legend prefix, rendered as ``"<label>|$\\sigma = x.xx$"``.
        max_radius: largest radius passed to ``_get_accuracies_at_radii``.
        radius_step: radius increment passed to ``_get_accuracies_at_radii``.
        **plot_kwargs: extra keyword arguments forwarded to ``ax.plot``
            (e.g. ``dashes``).
    """
    methods = list(methods)
    for it, sigma in enumerate(sigmas):
        sigma_tag = '{:.2f}'.format(sigma)
        methods_sigma = [m for m in methods if sigma_tag in m.quantity.data_file_path]
        if not methods_sigma:
            # Without any matching run, argmax below would fail on an empty array.
            print("no results for sigma = {} ({}), skipping".format(sigma_tag, label))
            continue
        # presumably shape (len(methods_sigma), n_radii) -- TODO confirm against
        # _get_accuracies_at_radii; the [:, 0] indexing below requires 2-D.
        accuracies, radii = _get_accuracies_at_radii(methods_sigma, 0, max_radius, radius_step)
        # Must be passed as nan=: the second positional parameter of
        # np.nan_to_num is `copy`. -1 ranks missing values below any real accuracy.
        accuracies = np.nan_to_num(accuracies, nan=-1.0)
        best = accuracies[:, 0].argmax()
        ax.plot(radii, accuracies[best, :],
                color=colors[it % len(colors)],
                label=r'{}|$\sigma = {:.2f}$'.format(label, sigma),
                **plot_kwargs)


def plot_certified_accuracy_per_sigma_best_model(outfile: str, title: str, max_radius: float,
                                                 methods: List[Line] = None, label='Ours',
                                                 methods_base: List[Line] = None, label_base='Baseline',
                                                 radius_step: float = 0.01, upper_bounds=False,
                                                 sigmas=(0.25,)) -> None:
    """Plot certified accuracy vs. l2 radius for the best model at each sigma.

    For every sigma, the curve from ``methods`` (solid line) and from
    ``methods_base`` (dashed line) with the highest value at radius 0 is drawn.
    The figure is saved to ``outfile + ".png"`` at 300 dpi and then shown with
    ``plt.show()``.

    Args:
        outfile: output path without extension; ``".png"`` is appended.
        title: currently unused (kept for API compatibility).
        max_radius: right x-limit and largest radius evaluated.
        methods: candidate ``Line`` objects for the main curves; ``None`` means none.
        label: legend prefix for ``methods``.
        methods_base: candidate ``Line`` objects for the baseline curves;
            ``None`` means none.
        label_base: legend prefix for ``methods_base``.
        radius_step: radius increment between evaluated points.
        upper_bounds: currently unused (kept for API compatibility).
        sigmas: noise levels to plot; matched against each method's
            ``data_file_path`` using two-decimal formatting.
    """
    colors = ['b', 'orange', 'g', 'r']
    methods = [] if methods is None else methods
    methods_base = [] if methods_base is None else methods_base

    # Original note: "patch for macos" -- an explicit figure/axes is used instead
    # of implicit pyplot state (the specific macOS issue isn't documented here).
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    _plot_best_curve_per_sigma(ax, methods, sigmas, colors, label, max_radius, radius_step)
    _plot_best_curve_per_sigma(ax, methods_base, sigmas, colors, label_base, max_radius, radius_step,
                               dashes=[2, 2])

    ax.set_ylim(0, 1)
    ax.set_xlim(0, max_radius)
    ax.tick_params(labelsize=14)
    ax.set_xlabel(r"$\ell_2$ radius", fontsize=16)
    ax.set_ylabel("Certified Accuracy", fontsize=16)
    ax.xaxis.set_major_locator(plt.MultipleLocator(0.5))
    ax.legend(loc='upper right', fontsize=16)
    fig.tight_layout()
    fig.savefig(outfile + ".png", dpi=300)
    plt.show()
    print("saved!", outfile)
# 【macos】matplotlib绘图 ([macOS] plotting with matplotlib)
# 最新推荐文章于 2024-11-08 20:19:05 发布 (latest recommended article published 2024-11-08 20:19:05)